From 84e588f3c473a56ac0ee736384c060c1e1f5f44a Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 14:00:41 +0200 Subject: [PATCH 1/6] Add helper utilities Adds two helper utilities for converting a structure to map[string]any as well as printing as JSON. --- pkg/backend/utils.go | 47 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 pkg/backend/utils.go diff --git a/pkg/backend/utils.go b/pkg/backend/utils.go new file mode 100644 index 0000000..2db3a64 --- /dev/null +++ b/pkg/backend/utils.go @@ -0,0 +1,47 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "encoding/json" + "fmt" +) + +// ToMap converts the given value to a map[string]any. This is useful for working with JSON data. +func ToMap(v any) (map[string]any, error) { + data, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("error marshaling to JSON: %v", err) + } + + var mapResult map[string]any + err = json.Unmarshal(data, &mapResult) + if err != nil { + return nil, fmt.Errorf("error unmarshaling JSON to map: %v", err) + } + + return mapResult, nil +} + +// PrintJSON prints the given value as a JSON string. Useful for debugging. 
+func PrintJSON(v any) { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + fmt.Printf("error marshaling to JSON: %v", err) + return + } + + fmt.Println(string(data)) +} From 72233571bd1288911df9e9ca3daccf273a61e747 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 14:01:22 +0200 Subject: [PATCH 2/6] Data structures and methods for executing tools Both OpenAI and Ollama have the same API structure for executing tools. This commit adds a tool registry that the library user can use to register a new tool that the model will be able to call. --- pkg/backend/tools.go | 90 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 pkg/backend/tools.go diff --git a/pkg/backend/tools.go b/pkg/backend/tools.go new file mode 100644 index 0000000..b8eba9e --- /dev/null +++ b/pkg/backend/tools.go @@ -0,0 +1,90 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "fmt" + "sync" +) + +// ErrToolNotFound is returned when a tool is not found in the registry. +var ErrToolNotFound = fmt.Errorf("tool not found") + +// ToolWrapper is a function type that wraps a tool's functionality. +type ToolWrapper func(args map[string]any) (string, error) + +// Tool represents a tool that can be executed. +type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +// ToolFunction represents the function signature of a tool. 
+type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` + Wrapper ToolWrapper `json:"-"` +} + +// ToolRegistry manages the registration of tools and their corresponding wrapper functions. +type ToolRegistry struct { + tools map[string]Tool + m sync.Mutex +} + +// NewToolRegistry initializes a new ToolRegistry. +func NewToolRegistry() *ToolRegistry { + return &ToolRegistry{ + tools: make(map[string]Tool), + } +} + +// RegisterTool allows the registration of a tool by name, expected parameters, and a wrapper function. +func (r *ToolRegistry) RegisterTool(t Tool) { + r.m.Lock() + r.tools[t.Function.Name] = t + r.m.Unlock() +} + +// ToolsMap returns a list of tools as a map of string to any. This is the format that both Ollama and OpenAI expect. +func (r *ToolRegistry) ToolsMap() ([]map[string]any, error) { + toolList := make([]map[string]any, 0, len(r.tools)) + r.m.Lock() + for _, tool := range r.tools { + tMap, err := ToMap(tool) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list to map: %w", err) + } + toolList = append(toolList, tMap) + } + r.m.Unlock() + + return toolList, nil +} + +// ExecuteTool looks up a tool by name, checks the provided arguments, and calls the registered wrapper function. +func (r *ToolRegistry) ExecuteTool(toolName string, args map[string]any) (string, error) { + r.m.Lock() + defer r.m.Unlock() + + toolEntry, exists := r.tools[toolName] + if !exists { + return "", fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + } + + // Call the tool's wrapper function with the provided arguments + return toolEntry.Function.Wrapper(args) +} From 7f15e81de420d9104022d59bc9b837f273d852c0 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 15:16:01 +0200 Subject: [PATCH 3/6] Implement the Converse method with tool support for Ollama Adds a Converse method for Ollama. 
Due to how Ollama models are implemented (https://github.com/ollama/ollama/issues/6127) we try to detect if a non-existing tool was called and just route the prompt to the model again with tools disabled. --- pkg/backend/backend.go | 1 + pkg/backend/ollama_backend.go | 156 ++++++++++++++++++++++++++--- pkg/backend/ollama_backend_test.go | 2 +- 3 files changed, 142 insertions(+), 17 deletions(-) diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 0cf7e15..5b4fc3d 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package backend import "context" diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 434b4db..ca4829e 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -17,6 +17,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -24,6 +25,7 @@ import ( ) const ( + chatEndpoint = "/api/chat" generateEndpoint = "/api/generate" embedEndpoint = "/api/embeddings" defaultTimeout = 30 * time.Second @@ -36,20 +38,39 @@ type OllamaBackend struct { BaseURL string } -// Response represents the structure of the response received from the Ollama API. -type Response struct { - Model string `json:"model"` - CreatedAt string `json:"created_at"` - Response string `json:"response"` - Done bool `json:"done"` - DoneReason string `json:"done_reason"` - Context []int `json:"context"` - TotalDuration int64 `json:"total_duration"` - LoadDuration int64 `json:"load_duration"` - PromptEvalCount int `json:"prompt_eval_count"` - PromptEvalDuration int64 `json:"prompt_eval_duration"` - EvalCount int `json:"eval_count"` - EvalDuration int64 `json:"eval_duration"` +// OllamaResponse represents the structure of the response received from the Ollama API. 
+type OllamaResponse struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + Context []int `json:"context"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration int64 `json:"eval_duration"` + Message OllamaResponseMessage `json:"message"` +} + +// OllamaResponseMessage represents the message part of the response from the Ollama API. +type OllamaResponseMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OllamaToolCall `json:"tool_calls"` +} + +// OllamaToolCall represents a tool call to be made by the Ollama API. +type OllamaToolCall struct { + Function OllamaFunctionCall `json:"function"` +} + +// OllamaFunctionCall represents a function call to be made by the Ollama API. +type OllamaFunctionCall struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` } // OllamaEmbeddingResponse represents the response from the Ollama API for embeddings. @@ -68,6 +89,109 @@ func NewOllamaBackend(baseURL, model string, timeout time.Duration) *OllamaBacke } } +type ollamaConversationOption struct { + disableTools bool +} + +// Converse drives a conversation with the Ollama API based on the given conversation context. 
+func (o *OllamaBackend) Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) { + resp, err := o.converseRoundTrip(ctx, prompt, ollamaConversationOption{}) + if errors.Is(err, ErrToolNotFound) { + // retry without tools if the error is due to a tool not being found + return o.converseRoundTrip(ctx, prompt, ollamaConversationOption{disableTools: true}) + } + + return resp, err +} + +func (o *OllamaBackend) converseRoundTrip(ctx context.Context, prompt *Prompt, opts ollamaConversationOption) (PromptResponse, error) { + msgMap, err := prompt.AsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert messages to map: %w", err) + } + + url := o.BaseURL + chatEndpoint + reqBody := map[string]any{ + "model": o.Model, + "messages": msgMap, + "stream": false, + } + + if !opts.disableTools { + toolMap, err := prompt.Tools.ToolsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert tools to map: %w", err) + } + reqBody["tools"] = toolMap + } + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := o.Client.Do(req) + if err != nil { + return PromptResponse{}, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return PromptResponse{}, fmt.Errorf("failed to generate response from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + } + + var result OllamaResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return PromptResponse{}, fmt.Errorf("failed to decode response: %w", err) + } + + if 
len(result.Message.ToolCalls) == 0 { + prompt.AddMessage("assistant", result.Message.Content) + return PromptResponse{ + Role: "assistant", + Content: result.Message.Content, + }, nil + } + + response := PromptResponse{ + Role: "tool", + ToolCalls: make([]ToolCall, 0, len(result.Message.ToolCalls)), + } + for _, toolCall := range result.Message.ToolCalls { + toolName := toolCall.Function.Name + toolArgs := toolCall.Function.Arguments + + toolResponse, err := prompt.Tools.ExecuteTool(toolName, toolArgs) + if errors.Is(err, ErrToolNotFound) { + // this is a bit of a hack. Ollama models always reply with tool calls and hallucinate + // the tool names if tools are given in the request, but the request is not actually + // tied to any tool. So we just ignore these and re-send the request with tools disabled + return PromptResponse{}, ErrToolNotFound + } else if err != nil { + return PromptResponse{}, fmt.Errorf("failed to execute tool: %w", err) + } + prompt.AddMessage("tool", toolResponse) + + response.ToolCalls = append(response.ToolCalls, ToolCall{ + Function: FunctionCall{ + Name: toolName, + Arguments: toolArgs, + Result: toolResponse, + }, + }) + } + + return response, nil + +} + // Generate produces a response from the Ollama API based on the given structured prompt. // // Parameters: @@ -129,8 +253,8 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, e ) } - var result Response - if err := json.NewDecoder(bytes.NewBuffer(bodyBytes)).Decode(&result); err != nil { + var result OllamaResponse + if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&result); err != nil { return "", fmt.Errorf("failed to decode response: %w", err) } diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index eec5bbf..aa712bf 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -28,7 +28,7 @@ const testEmbeddingText = "Test embedding text." 
func TestOllamaGenerate(t *testing.T) { t.Parallel() // Mock response from Ollama API - mockResponse := Response{ + mockResponse := OllamaResponse{ Model: "test-model", CreatedAt: time.Now().Format(time.RFC3339), Response: "This is a test response from Ollama.", From 9561cd32ac76df8455b163110db72eb6a8fed1d3 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 14:36:57 +0200 Subject: [PATCH 4/6] Implement the Converse method with tool support for OpenAI Implements conversations with tool support for OpenAI. The OpenAI models expect an assistant message that includes the details about the function call and the function call ID in the response. --- pkg/backend/openai_backend.go | 136 +++++++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 2 deletions(-) diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index aea6bcf..044d378 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -85,8 +85,9 @@ type OpenAIResponse struct { Choices []struct { Index int `json:"index"` Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` } `json:"message"` FinishReason string `json:"finish_reason"` } `json:"choices"` @@ -97,6 +98,137 @@ type OpenAIResponse struct { } `json:"usage"` } +// OpenAIResponseMessage represents the message part of the response from the OpenAI API. +type OpenAIResponseMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` +} + +// OpenAIToolCall represents the structure of a tool call made by the assistant. +type OpenAIToolCall struct { + ID string `json:"id"` // The unique ID of the tool call. + Type string `json:"type"` // The type of tool call (e.g., "function"). + Function OpenAIToolFunction `json:"function"` // The function being called. 
+} + +// OpenAIToolFunction represents the function call made within a tool call. +type OpenAIToolFunction struct { + Name string `json:"name"` // The name of the function being invoked. + Arguments string `json:"arguments"` // The JSON string containing the arguments for the function. +} + +// Converse sends a series of messages to the OpenAI API and returns the generated response. +func (o *OpenAIBackend) Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) { + msgMap, err := prompt.AsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert messages to map: %w", err) + } + + toolMap, err := prompt.Tools.ToolsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert tools to map: %w", err) + } + url := o.BaseURL + "/v1/chat/completions" + reqBody := map[string]any{ + "model": o.Model, + "messages": msgMap, + "stream": false, + "tools": toolMap, + } + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.APIKey) + + resp, err := o.HTTPClient.Do(req) + if err != nil { + return PromptResponse{}, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return PromptResponse{}, fmt.Errorf("failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + } + + var result OpenAIResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return PromptResponse{}, fmt.Errorf("failed to decode response: %w", err) + } + + if len(result.Choices[0].Message.ToolCalls) == 
0 { + prompt.AddMessage("assistant", result.Choices[0].Message.Content) + return PromptResponse{ + Role: "assistant", + Content: result.Choices[0].Message.Content, + }, nil + } + + response := PromptResponse{ + Role: "tool", + ToolCalls: make([]ToolCall, 0, len(result.Choices[0].Message.ToolCalls)), + } + for _, toolCall := range result.Choices[0].Message.ToolCalls { + toolName := toolCall.Function.Name + toolArgs := toolCall.Function.Arguments + + var parsedArgs map[string]interface{} + err = json.Unmarshal([]byte(toolArgs), &parsedArgs) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to unmarshal tool arguments: %w", err) + } + + toolResponse, err := prompt.Tools.ExecuteTool(toolName, parsedArgs) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to execute tool: %w", err) + } + + // we also need to add the previous reply with the call ID to the conversation + // todo: programatically convert + prompt.AppendMessage(Message{ + Role: "assistant", + Fields: map[string]any{ + "type": result.Choices[0].Message.ToolCalls[0].Type, + "tool_calls": []map[string]any{ + { + "id": result.Choices[0].Message.ToolCalls[0].ID, + "type": result.Choices[0].Message.ToolCalls[0].Type, + "function": map[string]any{ + "name": result.Choices[0].Message.ToolCalls[0].Function.Name, + "arguments": result.Choices[0].Message.ToolCalls[0].Function.Arguments, + }, + }, + }, + }, + }) + prompt.AppendMessage(Message{ + Role: "tool", + Content: toolResponse, + Fields: map[string]any{ + "tool_call_id": result.Choices[0].Message.ToolCalls[0].ID, + }, + }) + + response.ToolCalls = append(response.ToolCalls, ToolCall{ + Function: FunctionCall{ + Name: toolName, + Arguments: parsedArgs, + Result: toolResponse, + }, + }) + } + return response, nil +} + // Generate sends a prompt to the OpenAI API and returns the generated response. 
// // Parameters: From c23c46557a808aeb958703ae7afe28ce22dd26c5 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 15:15:43 +0200 Subject: [PATCH 5/6] Add the Converse method to the back end interface Adds Converse as a back end method, making it usable by library clients. --- pkg/backend/backend.go | 60 ++++++++++++++++++++++++++++-- pkg/backend/openai_backend_test.go | 10 +++-- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 5b4fc3d..feeb3ce 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -14,18 +14,22 @@ package backend -import "context" +import ( + "context" +) // Backend defines the interface for interacting with various LLM backends. type Backend interface { + Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) Generate(ctx context.Context, prompt *Prompt) (string, error) Embed(ctx context.Context, input string) ([]float32, error) } // Message represents a single role-based message in the conversation. type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + Fields map[string]any `json:"fields,omitempty"` } // Parameters defines generation settings for LLM completions. @@ -41,11 +45,17 @@ type Parameters struct { type Prompt struct { Messages []Message `json:"messages"` Parameters Parameters `json:"parameters"` + // ToolRegistry is a map of tool names to their corresponding wrapper functions. + Tools *ToolRegistry } // NewPrompt creates and returns a new Prompt. func NewPrompt() *Prompt { - return &Prompt{} + return &Prompt{ + Messages: make([]Message, 0), + Parameters: Parameters{}, + Tools: NewToolRegistry(), + } } // AddMessage adds a message with a specific role to the prompt. @@ -54,8 +64,50 @@ func (p *Prompt) AddMessage(role, content string) *Prompt { return p } +// AppendMessage adds a message with a specific role to the prompt. 
+func (p *Prompt) AppendMessage(msg Message) *Prompt { + p.Messages = append(p.Messages, msg) + return p +} + // SetParameters sets the generation parameters for the prompt. func (p *Prompt) SetParameters(params Parameters) *Prompt { p.Parameters = params return p } + +// AsMap returns the conversation's messages as a list of maps. +func (p *Prompt) AsMap() ([]map[string]any, error) { + messageList := make([]map[string]any, 0, len(p.Messages)) + for _, message := range p.Messages { + msgMap := map[string]any{ + "role": message.Role, + "content": message.Content, + } + for k, v := range message.Fields { + msgMap[k] = v + } + messageList = append(messageList, msgMap) + } + + return messageList, nil +} + +// FunctionCall represents a call to a function. +type FunctionCall struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Result any `json:"result"` +} + +// ToolCall represents a call to a tool. +type ToolCall struct { + Function FunctionCall `json:"function"` +} + +// PromptResponse represents a response from the AI in a conversation. 
+type PromptResponse struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} diff --git a/pkg/backend/openai_backend_test.go b/pkg/backend/openai_backend_test.go index e088a2d..1ad13e2 100644 --- a/pkg/backend/openai_backend_test.go +++ b/pkg/backend/openai_backend_test.go @@ -33,16 +33,18 @@ func TestGenerate(t *testing.T) { Choices: []struct { Index int `json:"index"` Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` } `json:"message"` FinishReason string `json:"finish_reason"` }{ { Index: 0, Message: struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` }{ Role: "assistant", Content: "This is a test response.", From 1427fb64ce70a0e6fd62a27cad9bbce254d66d5a Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Sun, 6 Oct 2024 14:37:11 +0200 Subject: [PATCH 6/6] Add an example with a tool pretending to do weather reports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an example of tool function with pretend weather reports. Example invocations: ``` BACKEND=ollama go run ./examples/tools/main.go "What is the weather in London?" 2024/10/17 22:42:11 No model selected with the MODEL env variable. Defaulting to qwen2.5 2024/10/17 22:42:11 Using Ollama backend: qwen2.5 2024/10/17 22:42:14 Tool called 2024/10/17 22:42:14 Response: 2024/10/17 22:42:14 The current temperature in London is 15°C and the conditions are rainy. 
``` --- examples/tools/main.go | 118 +++++++++++++++++++++++++ examples/tools/weather/weather_tool.go | 79 +++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 examples/tools/main.go create mode 100644 examples/tools/weather/weather_tool.go diff --git a/examples/tools/main.go b/examples/tools/main.go new file mode 100644 index 0000000..7e5df5e --- /dev/null +++ b/examples/tools/main.go @@ -0,0 +1,118 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "log" + "os" + "strings" + "time" + + "github.com/stackloklabs/gollm/examples/tools/weather" + "github.com/stackloklabs/gollm/pkg/backend" +) + +var ( + ollamaHost = "http://localhost:11434" + ollamaGenModel = "qwen2.5" + openaiModel = "gpt-4o-mini" +) + +const ( + systemMessage = ` +You are an AI assistant that can retrieve weather forecasts by calling a tool +that returns weather data in JSON format. You will be provided with a city +name, and you will use a tool to call out to a weather forecast API that +provides the weather for that city. The tool returns a JSON object with three +fields: city, temperature, and conditions. +` + summarizeMessage = ` +Summarize the tool's forecast of the weather in clear, plain language for the user. 
+` +) + +func main() { + var generationBackend backend.Backend + + beSelection := os.Getenv("BACKEND") + if beSelection == "" { + log.Println("No backend selected with the BACKEND env variable. Defaulting to Ollama.") + beSelection = "ollama" + } + modelSelection := os.Getenv("MODEL") + if modelSelection == "" { + switch beSelection { + case "ollama": + modelSelection = ollamaGenModel + case "openai": + modelSelection = openaiModel + } + log.Println("No model selected with the MODEL env variable. Defaulting to ", modelSelection) + } + + switch beSelection { + case "ollama": + generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel, 30*time.Second) + log.Println("Using Ollama backend: ", ollamaGenModel) + case "openai": + openaiKey := os.Getenv("OPENAI_API_KEY") + if openaiKey == "" { + log.Fatalf("OPENAI_API_KEY is required for OpenAI backend") + } + generationBackend = backend.NewOpenAIBackend(openaiKey, openaiModel, 30*time.Second) + log.Println("Using OpenAI backend: ", openaiModel) + default: + log.Fatalf("Unknown backend: %s", beSelection) + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + userPrompt := os.Args[1:] + if len(userPrompt) == 0 { + log.Fatalf("Please provide a prompt") + } + + convo := backend.NewPrompt() + convo.Tools.RegisterTool(weather.Tool()) + // start the conversation. We add a system message to tune the output + // and add the weather tool to the conversation so that the model knows to call it. 
+ convo.AddMessage("system", systemMessage) + convo.AddMessage("user", strings.Join(userPrompt, " ")) + + // generate the response + resp, err := generationBackend.Converse(ctx, convo) + if err != nil { + log.Fatalf("Error generating response: %v", err) + } + + if len(resp.ToolCalls) == 0 { + log.Println("No tool calls in response.") + log.Println("Response:", convo.Messages[len(convo.Messages)-1].Content) + return + } + + log.Println("Tool called") + + // if there was a tool response, first just feed it back to the model so it makes sense of it + _, err = generationBackend.Converse(ctx, convo) + if err != nil { + log.Fatalf("Error generating response: %v", err) + } + + log.Println("Response:") + log.Println(convo.Messages[len(convo.Messages)-1].Content) +} diff --git a/examples/tools/weather/weather_tool.go b/examples/tools/weather/weather_tool.go new file mode 100644 index 0000000..0a833ad --- /dev/null +++ b/examples/tools/weather/weather_tool.go @@ -0,0 +1,79 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package weather + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/stackloklabs/gollm/pkg/backend" +) + +// Tool returns a backend.Tool object that can be used to interact with the weather tool. 
+func Tool() backend.Tool { + return backend.Tool{ + Type: "function", + Function: backend.ToolFunction{ + Name: "weather", + Description: "Get weather report for a city", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{ + "type": "string", + "description": "The city for which to get the weather report", + }, + }, + "required": []string{"city"}, + }, + Wrapper: weatherReportWrapper, + }, + } +} + +func weatherReportWrapper(params map[string]any) (string, error) { + city, ok := params["city"].(string) + if !ok { + return "", fmt.Errorf("city must be a string") + } + return weatherReport(city) +} + +// WeatherReport defines the structure of the JSON response +type WeatherReport struct { + City string `json:"city"` + Temperature string `json:"temperature"` + Conditions string `json:"conditions"` +} + +// weatherReport returns a dummy weather report for the specified city in JSON format. +func weatherReport(city string) (string, error) { + // in a real application, this data would be fetched from an external API + weatherData := map[string]WeatherReport{ + "London": {City: "London", Temperature: "15°C", Conditions: "Rainy"}, + "Stockholm": {City: "Stockholm", Temperature: "10°C", Conditions: "Sunny"}, + "Brno": {City: "Brno", Temperature: "18°C", Conditions: "Clear skies"}, + } + + if report, ok := weatherData[city]; ok { + jsonReport, err := json.Marshal(report) + if err != nil { + return "", err + } + return string(jsonReport), nil + } + + return "", errors.New("city not found") +}