diff --git a/examples/tools/main.go b/examples/tools/main.go new file mode 100644 index 0000000..7e5df5e --- /dev/null +++ b/examples/tools/main.go @@ -0,0 +1,118 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "log" + "os" + "strings" + "time" + + "github.com/stackloklabs/gollm/examples/tools/weather" + "github.com/stackloklabs/gollm/pkg/backend" +) + +var ( + ollamaHost = "http://localhost:11434" + ollamaGenModel = "qwen2.5" + openaiModel = "gpt-4o-mini" +) + +const ( + systemMessage = ` +You are an AI assistant that can retrieve weather forecasts by calling a tool +that returns weather data in JSON format. You will be provided with a city +name, and you will use a tool to call out to a weather forecast API that +provides the weather for that city. The tool returns a JSON object with three +fields: city, temperature, and conditions. +` + summarizeMessage = ` +Summarize the tool's forecast of the weather in clear, plain language for the user. +` +) + +func main() { + var generationBackend backend.Backend + + beSelection := os.Getenv("BACKEND") + if beSelection == "" { + log.Println("No backend selected with the BACKEND env variable. 
Defaulting to Ollama.") + beSelection = "ollama" + } + modelSelection := os.Getenv("MODEL") + if modelSelection == "" { + switch beSelection { + case "ollama": + modelSelection = ollamaGenModel + case "openai": + modelSelection = openaiModel + } + log.Println("No model selected with the MODEL env variable. Defaulting to ", modelSelection) + } + + switch beSelection { + case "ollama": + generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel, 30*time.Second) + log.Println("Using Ollama backend: ", ollamaGenModel) + case "openai": + openaiKey := os.Getenv("OPENAI_API_KEY") + if openaiKey == "" { + log.Fatalf("OPENAI_API_KEY is required for OpenAI backend") + } + generationBackend = backend.NewOpenAIBackend(openaiKey, openaiModel, 30*time.Second) + log.Println("Using OpenAI backend: ", openaiModel) + default: + log.Fatalf("Unknown backend: %s", beSelection) + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + userPrompt := os.Args[1:] + if len(userPrompt) == 0 { + log.Fatalf("Please provide a prompt") + } + + convo := backend.NewPrompt() + convo.Tools.RegisterTool(weather.Tool()) + // start the conversation. We add a system message to tune the output + // and add the weather tool to the conversation so that the model knows to call it. 
+ convo.AddMessage("system", systemMessage) + convo.AddMessage("user", strings.Join(userPrompt, " ")) + + // generate the response + resp, err := generationBackend.Converse(ctx, convo) + if err != nil { + log.Fatalf("Error generating response: %v", err) + } + + if len(resp.ToolCalls) == 0 { + log.Println("No tool calls in response.") + log.Println("Response:", convo.Messages[len(convo.Messages)-1].Content) + return + } + + log.Println("Tool called") + + // if there was a tool response, first just feed it back to the model so it makes sense of it + _, err = generationBackend.Converse(ctx, convo) + if err != nil { + log.Fatalf("Error generating response: %v", err) + } + + log.Println("Response:") + log.Println(convo.Messages[len(convo.Messages)-1].Content) +} diff --git a/examples/tools/weather/weather_tool.go b/examples/tools/weather/weather_tool.go new file mode 100644 index 0000000..0a833ad --- /dev/null +++ b/examples/tools/weather/weather_tool.go @@ -0,0 +1,79 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package weather + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/stackloklabs/gollm/pkg/backend" +) + +// Tool returns a backend.Tool object that can be used to interact with the weather tool. 
+func Tool() backend.Tool { + return backend.Tool{ + Type: "function", + Function: backend.ToolFunction{ + Name: "weather", + Description: "Get weather report for a city", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{ + "type": "string", + "description": "The city for which to get the weather report", + }, + }, + "required": []string{"city"}, + }, + Wrapper: weatherReportWrapper, + }, + } +} + +func weatherReportWrapper(params map[string]any) (string, error) { + city, ok := params["city"].(string) + if !ok { + return "", fmt.Errorf("city must be a string") + } + return weatherReport(city) +} + +// WeatherReport defines the structure of the JSON response +type WeatherReport struct { + City string `json:"city"` + Temperature string `json:"temperature"` + Conditions string `json:"conditions"` +} + +// weatherReport returns a dummy weather report for the specified city in JSON format. +func weatherReport(city string) (string, error) { + // in a real application, this data would be fetched from an external API + weatherData := map[string]WeatherReport{ + "London": {City: "London", Temperature: "15°C", Conditions: "Rainy"}, + "Stockholm": {City: "Stockholm", Temperature: "10°C", Conditions: "Sunny"}, + "Brno": {City: "Brno", Temperature: "18°C", Conditions: "Clear skies"}, + } + + if report, ok := weatherData[city]; ok { + jsonReport, err := json.Marshal(report) + if err != nil { + return "", err + } + return string(jsonReport), nil + } + + return "", errors.New("city not found") +} diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 0cf7e15..feeb3ce 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -11,20 +11,25 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+ package backend -import "context" +import ( + "context" +) // Backend defines the interface for interacting with various LLM backends. type Backend interface { + Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) Generate(ctx context.Context, prompt *Prompt) (string, error) Embed(ctx context.Context, input string) ([]float32, error) } // Message represents a single role-based message in the conversation. type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + Fields map[string]any `json:"fields,omitempty"` } // Parameters defines generation settings for LLM completions. @@ -40,11 +45,17 @@ type Parameters struct { type Prompt struct { Messages []Message `json:"messages"` Parameters Parameters `json:"parameters"` + // ToolRegistry is a map of tool names to their corresponding wrapper functions. + Tools *ToolRegistry } // NewPrompt creates and returns a new Prompt. func NewPrompt() *Prompt { - return &Prompt{} + return &Prompt{ + Messages: make([]Message, 0), + Parameters: Parameters{}, + Tools: NewToolRegistry(), + } } // AddMessage adds a message with a specific role to the prompt. @@ -53,8 +64,50 @@ func (p *Prompt) AddMessage(role, content string) *Prompt { return p } +// AppendMessage adds a message with a specific role to the prompt. +func (p *Prompt) AppendMessage(msg Message) *Prompt { + p.Messages = append(p.Messages, msg) + return p +} + // SetParameters sets the generation parameters for the prompt. func (p *Prompt) SetParameters(params Parameters) *Prompt { p.Parameters = params return p } + +// AsMap returns the conversation's messages as a list of maps. 
+func (p *Prompt) AsMap() ([]map[string]any, error) { + messageList := make([]map[string]any, 0, len(p.Messages)) + for _, message := range p.Messages { + msgMap := map[string]any{ + "role": message.Role, + "content": message.Content, + } + for k, v := range message.Fields { + msgMap[k] = v + } + messageList = append(messageList, msgMap) + } + + return messageList, nil +} + +// FunctionCall represents a call to a function. +type FunctionCall struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Result any `json:"result"` +} + +// ToolCall represents a call to a tool. +type ToolCall struct { + Function FunctionCall `json:"function"` +} + +// PromptResponse represents a response from the AI in a conversation. +type PromptResponse struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 434b4db..ca4829e 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -17,6 +17,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -24,6 +25,7 @@ import ( ) const ( + chatEndpoint = "/api/chat" generateEndpoint = "/api/generate" embedEndpoint = "/api/embeddings" defaultTimeout = 30 * time.Second @@ -36,20 +38,39 @@ type OllamaBackend struct { BaseURL string } -// Response represents the structure of the response received from the Ollama API. 
// OllamaResponse represents the structure of the response received from the
// Ollama API. It is shared by /api/generate (which fills Response, Context and
// the timing counters) and /api/chat (which fills Message); fields not present
// in a given endpoint's reply simply decode to their zero values.
type OllamaResponse struct {
	Model              string                `json:"model"`
	CreatedAt          string                `json:"created_at"`
	Response           string                `json:"response"`
	Done               bool                  `json:"done"`
	DoneReason         string                `json:"done_reason"`
	Context            []int                 `json:"context"`
	TotalDuration      int64                 `json:"total_duration"`
	LoadDuration       int64                 `json:"load_duration"`
	PromptEvalCount    int                   `json:"prompt_eval_count"`
	PromptEvalDuration int64                 `json:"prompt_eval_duration"`
	EvalCount          int                   `json:"eval_count"`
	EvalDuration       int64                 `json:"eval_duration"`
	Message            OllamaResponseMessage `json:"message"`
}

// OllamaResponseMessage represents the message part of the response from the
// Ollama API: the assistant's role/content plus any tool calls it requested.
type OllamaResponseMessage struct {
	Role      string           `json:"role"`
	Content   string           `json:"content"`
	ToolCalls []OllamaToolCall `json:"tool_calls"`
}

// OllamaToolCall represents a tool call to be made by the Ollama API.
type OllamaToolCall struct {
	Function OllamaFunctionCall `json:"function"`
}

// OllamaFunctionCall represents a function call to be made by the Ollama API.
// Note that Ollama delivers the arguments as an already-parsed JSON object,
// unlike OpenAI, which delivers them as a raw JSON string.
type OllamaFunctionCall struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments"`
}
@@ -68,6 +89,109 @@ func NewOllamaBackend(baseURL, model string, timeout time.Duration) *OllamaBacke } } +type ollamaConversationOption struct { + disableTools bool +} + +// Converse drives a conversation with the Ollama API based on the given conversation context. +func (o *OllamaBackend) Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) { + resp, err := o.converseRoundTrip(ctx, prompt, ollamaConversationOption{}) + if errors.Is(err, ErrToolNotFound) { + // retry without tools if the error is due to a tool not being found + return o.converseRoundTrip(ctx, prompt, ollamaConversationOption{disableTools: true}) + } + + return resp, err +} + +func (o *OllamaBackend) converseRoundTrip(ctx context.Context, prompt *Prompt, opts ollamaConversationOption) (PromptResponse, error) { + msgMap, err := prompt.AsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert messages to map: %w", err) + } + + url := o.BaseURL + chatEndpoint + reqBody := map[string]any{ + "model": o.Model, + "messages": msgMap, + "stream": false, + } + + if !opts.disableTools { + toolMap, err := prompt.Tools.ToolsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert tools to map: %w", err) + } + reqBody["tools"] = toolMap + } + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := o.Client.Do(req) + if err != nil { + return PromptResponse{}, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return PromptResponse{}, fmt.Errorf("failed to generate response from Ollama: status 
code %d, response: %s", resp.StatusCode, string(bodyBytes)) + } + + var result OllamaResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return PromptResponse{}, fmt.Errorf("failed to decode response: %w", err) + } + + if len(result.Message.ToolCalls) == 0 { + prompt.AddMessage("assistant", result.Message.Content) + return PromptResponse{ + Role: "assistant", + Content: result.Message.Content, + }, nil + } + + response := PromptResponse{ + Role: "tool", + ToolCalls: make([]ToolCall, 0, len(result.Message.ToolCalls)), + } + for _, toolCall := range result.Message.ToolCalls { + toolName := toolCall.Function.Name + toolArgs := toolCall.Function.Arguments + + toolResponse, err := prompt.Tools.ExecuteTool(toolName, toolArgs) + if errors.Is(err, ErrToolNotFound) { + // this is a bit of a hack. Ollama models always reply with tool calls and hallucinate + // the tool names if tools are given in the request, but the request is not actually + // tied to any tool. So we just ignore these and re-send the request with tools disabled + return PromptResponse{}, ErrToolNotFound + } else if err != nil { + return PromptResponse{}, fmt.Errorf("failed to execute tool: %w", err) + } + prompt.AddMessage("tool", toolResponse) + + response.ToolCalls = append(response.ToolCalls, ToolCall{ + Function: FunctionCall{ + Name: toolName, + Arguments: toolArgs, + Result: toolResponse, + }, + }) + } + + return response, nil + +} + // Generate produces a response from the Ollama API based on the given structured prompt. 
// // Parameters: @@ -129,8 +253,8 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, e ) } - var result Response - if err := json.NewDecoder(bytes.NewBuffer(bodyBytes)).Decode(&result); err != nil { + var result OllamaResponse + if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&result); err != nil { return "", fmt.Errorf("failed to decode response: %w", err) } diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index eec5bbf..aa712bf 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -28,7 +28,7 @@ const testEmbeddingText = "Test embedding text." func TestOllamaGenerate(t *testing.T) { t.Parallel() // Mock response from Ollama API - mockResponse := Response{ + mockResponse := OllamaResponse{ Model: "test-model", CreatedAt: time.Now().Format(time.RFC3339), Response: "This is a test response from Ollama.", diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index aea6bcf..044d378 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -85,8 +85,9 @@ type OpenAIResponse struct { Choices []struct { Index int `json:"index"` Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` } `json:"message"` FinishReason string `json:"finish_reason"` } `json:"choices"` @@ -97,6 +98,137 @@ type OpenAIResponse struct { } `json:"usage"` } +// OpenAIResponseMessage represents the message part of the response from the OpenAI API. +type OpenAIResponseMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` +} + +// OpenAIToolCall represents the structure of a tool call made by the assistant. +type OpenAIToolCall struct { + ID string `json:"id"` // The unique ID of the tool call. 
+ Type string `json:"type"` // The type of tool call (e.g., "function"). + Function OpenAIToolFunction `json:"function"` // The function being called. +} + +// OpenAIToolFunction represents the function call made within a tool call. +type OpenAIToolFunction struct { + Name string `json:"name"` // The name of the function being invoked. + Arguments string `json:"arguments"` // The JSON string containing the arguments for the function. +} + +// Converse sends a series of messages to the OpenAI API and returns the generated response. +func (o *OpenAIBackend) Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error) { + msgMap, err := prompt.AsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert messages to map: %w", err) + } + + toolMap, err := prompt.Tools.ToolsMap() + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to convert tools to map: %w", err) + } + url := o.BaseURL + "/v1/chat/completions" + reqBody := map[string]any{ + "model": o.Model, + "messages": msgMap, + "stream": false, + "tools": toolMap, + } + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.APIKey) + + resp, err := o.HTTPClient.Do(req) + if err != nil { + return PromptResponse{}, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return PromptResponse{}, fmt.Errorf("failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + } + + var result OpenAIResponse + if err := 
json.NewDecoder(resp.Body).Decode(&result); err != nil { + return PromptResponse{}, fmt.Errorf("failed to decode response: %w", err) + } + + if len(result.Choices[0].Message.ToolCalls) == 0 { + prompt.AddMessage("assistant", result.Choices[0].Message.Content) + return PromptResponse{ + Role: "assistant", + Content: result.Choices[0].Message.Content, + }, nil + } + + response := PromptResponse{ + Role: "tool", + ToolCalls: make([]ToolCall, 0, len(result.Choices[0].Message.ToolCalls)), + } + for _, toolCall := range result.Choices[0].Message.ToolCalls { + toolName := toolCall.Function.Name + toolArgs := toolCall.Function.Arguments + + var parsedArgs map[string]interface{} + err = json.Unmarshal([]byte(toolArgs), &parsedArgs) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to unmarshal tool arguments: %w", err) + } + + toolResponse, err := prompt.Tools.ExecuteTool(toolName, parsedArgs) + if err != nil { + return PromptResponse{}, fmt.Errorf("failed to execute tool: %w", err) + } + + // we also need to add the previous reply with the call ID to the conversation + // todo: programatically convert + prompt.AppendMessage(Message{ + Role: "assistant", + Fields: map[string]any{ + "type": result.Choices[0].Message.ToolCalls[0].Type, + "tool_calls": []map[string]any{ + { + "id": result.Choices[0].Message.ToolCalls[0].ID, + "type": result.Choices[0].Message.ToolCalls[0].Type, + "function": map[string]any{ + "name": result.Choices[0].Message.ToolCalls[0].Function.Name, + "arguments": result.Choices[0].Message.ToolCalls[0].Function.Arguments, + }, + }, + }, + }, + }) + prompt.AppendMessage(Message{ + Role: "tool", + Content: toolResponse, + Fields: map[string]any{ + "tool_call_id": result.Choices[0].Message.ToolCalls[0].ID, + }, + }) + + response.ToolCalls = append(response.ToolCalls, ToolCall{ + Function: FunctionCall{ + Name: toolName, + Arguments: parsedArgs, + Result: toolResponse, + }, + }) + } + return response, nil +} + // Generate sends a prompt to the 
OpenAI API and returns the generated response. // // Parameters: diff --git a/pkg/backend/openai_backend_test.go b/pkg/backend/openai_backend_test.go index e088a2d..1ad13e2 100644 --- a/pkg/backend/openai_backend_test.go +++ b/pkg/backend/openai_backend_test.go @@ -33,16 +33,18 @@ func TestGenerate(t *testing.T) { Choices: []struct { Index int `json:"index"` Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` } `json:"message"` FinishReason string `json:"finish_reason"` }{ { Index: 0, Message: struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []OpenAIToolCall `json:"tool_calls"` }{ Role: "assistant", Content: "This is a test response.", diff --git a/pkg/backend/tools.go b/pkg/backend/tools.go new file mode 100644 index 0000000..b8eba9e --- /dev/null +++ b/pkg/backend/tools.go @@ -0,0 +1,90 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "fmt" + "sync" +) + +// ErrToolNotFound is returned when a tool is not found in the registry. +var ErrToolNotFound = fmt.Errorf("tool not found") + +// ToolWrapper is a function type that wraps a tool's functionality. +type ToolWrapper func(args map[string]any) (string, error) + +// Tool represents a tool that can be executed. 
+type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +// ToolFunction represents the function signature of a tool. +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` + Wrapper ToolWrapper `json:"-"` +} + +// ToolRegistry manages the registration of tools and their corresponding wrapper functions. +type ToolRegistry struct { + tools map[string]Tool + m sync.Mutex +} + +// NewToolRegistry initializes a new ToolRegistry. +func NewToolRegistry() *ToolRegistry { + return &ToolRegistry{ + tools: make(map[string]Tool), + } +} + +// RegisterTool allows the registration of a tool by name, expected parameters, and a wrapper function. +func (r *ToolRegistry) RegisterTool(t Tool) { + r.m.Lock() + r.tools[t.Function.Name] = t + r.m.Unlock() +} + +// ToolsMap returns a list of tools as a map of string to any. This is the format that both Ollama and OpenAI expect. +func (r *ToolRegistry) ToolsMap() ([]map[string]any, error) { + toolList := make([]map[string]any, 0, len(r.tools)) + r.m.Lock() + for _, tool := range r.tools { + tMap, err := ToMap(tool) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list to map: %w", err) + } + toolList = append(toolList, tMap) + } + r.m.Unlock() + + return toolList, nil +} + +// ExecuteTool looks up a tool by name, checks the provided arguments, and calls the registered wrapper function. 
+func (r *ToolRegistry) ExecuteTool(toolName string, args map[string]any) (string, error) { + r.m.Lock() + defer r.m.Unlock() + + toolEntry, exists := r.tools[toolName] + if !exists { + return "", fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + } + + // Call the tool's wrapper function with the provided arguments + return toolEntry.Function.Wrapper(args) +} diff --git a/pkg/backend/utils.go b/pkg/backend/utils.go new file mode 100644 index 0000000..2db3a64 --- /dev/null +++ b/pkg/backend/utils.go @@ -0,0 +1,47 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "encoding/json" + "fmt" +) + +// ToMap converts the given value to a map[string]any. This is useful for working with JSON data. +func ToMap(v any) (map[string]any, error) { + data, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("error marshaling to JSON: %v", err) + } + + var mapResult map[string]any + err = json.Unmarshal(data, &mapResult) + if err != nil { + return nil, fmt.Errorf("error unmarshaling JSON to map: %v", err) + } + + return mapResult, nil +} + +// PrintJSON prints the given value as a JSON string. Useful for debugging. +func PrintJSON(v any) { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + fmt.Printf("error marshaling to JSON: %v", err) + return + } + + fmt.Println(string(data)) +}