Add support for calling tools in conversations #4

Open
wants to merge 6 commits into main
118 changes: 118 additions & 0 deletions examples/tools/main.go
@@ -0,0 +1,118 @@
// Copyright 2024 Stacklok, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"log"
"os"
"strings"
"time"

"github.com/stackloklabs/gollm/examples/tools/weather"
"github.com/stackloklabs/gollm/pkg/backend"
)

var (
ollamaHost = "http://localhost:11434"
ollamaGenModel = "qwen2.5"
openaiModel = "gpt-4o-mini"
)

const (
systemMessage = `
You are an AI assistant that can retrieve weather forecasts by calling a tool
that returns weather data in JSON format. You will be provided with a city
name, and you will use a tool to call out to a weather forecast API that
provides the weather for that city. The tool returns a JSON object with three
fields: city, temperature, and conditions.
`
summarizeMessage = `
Summarize the tool's forecast of the weather in clear, plain language for the user.
`
)

func main() {
var generationBackend backend.Backend

beSelection := os.Getenv("BACKEND")
if beSelection == "" {
log.Println("No backend selected with the BACKEND env variable. Defaulting to Ollama.")
beSelection = "ollama"
}
modelSelection := os.Getenv("MODEL")
if modelSelection == "" {
switch beSelection {
case "ollama":
modelSelection = ollamaGenModel
case "openai":
modelSelection = openaiModel
}
log.Println("No model selected with the MODEL env variable. Defaulting to ", modelSelection)
}

	switch beSelection {
	case "ollama":
		generationBackend = backend.NewOllamaBackend(ollamaHost, modelSelection, 30*time.Second)
		log.Println("Using Ollama backend:", modelSelection)
	case "openai":
		openaiKey := os.Getenv("OPENAI_API_KEY")
		if openaiKey == "" {
			log.Fatalf("OPENAI_API_KEY is required for the OpenAI backend")
		}
		generationBackend = backend.NewOpenAIBackend(openaiKey, modelSelection, 30*time.Second)
		log.Println("Using OpenAI backend:", modelSelection)
	default:
		log.Fatalf("Unknown backend: %s", beSelection)
	}

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

userPrompt := os.Args[1:]
if len(userPrompt) == 0 {
log.Fatalf("Please provide a prompt")
}

convo := backend.NewPrompt()
	// Start the conversation: register the weather tool so the model knows it
	// can call it, and add a system message to tune the output.
	convo.Tools.RegisterTool(weather.Tool())
	convo.AddMessage("system", systemMessage)
convo.AddMessage("user", strings.Join(userPrompt, " "))

// generate the response
resp, err := generationBackend.Converse(ctx, convo)
if err != nil {
log.Fatalf("Error generating response: %v", err)
}

if len(resp.ToolCalls) == 0 {
log.Println("No tool calls in response.")
log.Println("Response:", convo.Messages[len(convo.Messages)-1].Content)
return
}

log.Println("Tool called")

	// The first Converse call ran the tool and appended its result to the
	// conversation; now ask the model to turn that result into a plain-language
	// summary for the user.
	convo.AddMessage("system", summarizeMessage)
	_, err = generationBackend.Converse(ctx, convo)
if err != nil {
log.Fatalf("Error generating response: %v", err)
}

log.Println("Response:")
log.Println(convo.Messages[len(convo.Messages)-1].Content)
}
79 changes: 79 additions & 0 deletions examples/tools/weather/weather_tool.go
@@ -0,0 +1,79 @@
// Copyright 2024 Stacklok, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package weather

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/stackloklabs/gollm/pkg/backend"
)

// Tool returns a backend.Tool object that can be used to interact with the weather tool.
func Tool() backend.Tool {
return backend.Tool{
Type: "function",
Function: backend.ToolFunction{
Name: "weather",
Description: "Get weather report for a city",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"city": map[string]any{
"type": "string",
"description": "The city for which to get the weather report",
},
},
"required": []string{"city"},
},
Wrapper: weatherReportWrapper,
},
}
}

func weatherReportWrapper(params map[string]any) (string, error) {
city, ok := params["city"].(string)
if !ok {
return "", fmt.Errorf("city must be a string")
}
return weatherReport(city)
}

// WeatherReport defines the structure of the JSON response
type WeatherReport struct {
City string `json:"city"`
Temperature string `json:"temperature"`
Conditions string `json:"conditions"`
}

// weatherReport returns a dummy weather report for the specified city in JSON format.
func weatherReport(city string) (string, error) {
// in a real application, this data would be fetched from an external API
weatherData := map[string]WeatherReport{
"London": {City: "London", Temperature: "15°C", Conditions: "Rainy"},
"Stockholm": {City: "Stockholm", Temperature: "10°C", Conditions: "Sunny"},
"Brno": {City: "Brno", Temperature: "18°C", Conditions: "Clear skies"},
}

if report, ok := weatherData[city]; ok {
jsonReport, err := json.Marshal(report)
if err != nil {
return "", err
}
return string(jsonReport), nil
}

return "", errors.New("city not found")
}
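
The tool above serves canned data; as the comment in weatherReport notes, a real application would fetch the forecast from an external API. Below is a minimal sketch of what an HTTP-backed variant could look like. The endpoint URL and its response shape are assumptions for illustration, not a real service, and it would need "context", "net/http", and "net/url" added to the imports above.

// weatherReportHTTP is a sketch only: it queries a hypothetical forecast API
// and returns the same JSON shape the canned weatherReport produces.
func weatherReportHTTP(ctx context.Context, city string) (string, error) {
	// Hypothetical endpoint; substitute a real weather API here.
	endpoint := "https://weather.example.com/forecast?city=" + url.QueryEscape(city)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("forecast API returned status %d", resp.StatusCode)
	}

	var report WeatherReport
	if err := json.NewDecoder(resp.Body).Decode(&report); err != nil {
		return "", err
	}

	// Re-encode so the model still receives the documented
	// {city, temperature, conditions} JSON object.
	out, err := json.Marshal(report)
	if err != nil {
		return "", err
	}
	return string(out), nil
}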
61 changes: 57 additions & 4 deletions pkg/backend/backend.go
@@ -11,20 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import "context"
import (
"context"
)

// Backend defines the interface for interacting with various LLM backends.
type Backend interface {
Converse(ctx context.Context, prompt *Prompt) (PromptResponse, error)
Generate(ctx context.Context, prompt *Prompt) (string, error)
Embed(ctx context.Context, input string) ([]float32, error)
}

// Message represents a single role-based message in the conversation.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
Fields map[string]any `json:"fields,omitempty"`
}

// Parameters defines generation settings for LLM completions.
@@ -40,11 +45,17 @@ type Parameters struct {
type Prompt struct {
Messages []Message `json:"messages"`
Parameters Parameters `json:"parameters"`
// Tools is the registry of tools available in the conversation, mapping tool names to their wrapper functions.
Tools *ToolRegistry
}

// NewPrompt creates and returns a new Prompt.
func NewPrompt() *Prompt {
return &Prompt{}
return &Prompt{
Messages: make([]Message, 0),
Parameters: Parameters{},
Tools: NewToolRegistry(),
}
}

// AddMessage adds a message with a specific role to the prompt.
@@ -53,8 +64,50 @@ func (p *Prompt) AddMessage(role, content string) *Prompt {
return p
}

// AppendMessage appends an already-constructed Message to the prompt.
func (p *Prompt) AppendMessage(msg Message) *Prompt {
p.Messages = append(p.Messages, msg)
return p
}

// SetParameters sets the generation parameters for the prompt.
func (p *Prompt) SetParameters(params Parameters) *Prompt {
p.Parameters = params
return p
}

// AsMap returns the conversation's messages as a list of maps.
func (p *Prompt) AsMap() ([]map[string]any, error) {
messageList := make([]map[string]any, 0, len(p.Messages))
for _, message := range p.Messages {
msgMap := map[string]any{
"role": message.Role,
"content": message.Content,
}
for k, v := range message.Fields {
msgMap[k] = v
}
messageList = append(messageList, msgMap)
}

return messageList, nil
}

// FunctionCall represents a call to a function.
type FunctionCall struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
Result any `json:"result"`
}

// ToolCall represents a call to a tool.
type ToolCall struct {
Function FunctionCall `json:"function"`
}

// PromptResponse represents a response from the AI in a conversation.
type PromptResponse struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
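
As a quick illustration of the new response types, here is a small sketch of how a caller might inspect PromptResponse.ToolCalls after Converse returns. It only uses types added in this diff; the assumption that the backend has already executed each tool call and populated Function.Result is taken from the examples/tools example rather than from this file.

// logToolCalls is a sketch for a caller outside this package (imports "log"
// and gollm's pkg/backend). It assumes the backend has already run each tool
// call and filled in Function.Result.
func logToolCalls(resp backend.PromptResponse) {
	for _, tc := range resp.ToolCalls {
		log.Printf("model requested tool %q with arguments %v",
			tc.Function.Name, tc.Function.Arguments)
		if tc.Function.Result != nil {
			log.Printf("tool result: %v", tc.Function.Result)
		}
	}
}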