-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example with a tool talking to Trusty
Adds an example of tool function for trusty and an example program that calls the model with tools.
- Loading branch information
Showing
2 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// Copyright 2024 Stacklok, Inc | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package main | ||
|
||
import ( | ||
"context" | ||
"github.com/stackloklabs/gollm/examples/tools/trusty" | ||
"log" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
"github.com/stackloklabs/gollm/pkg/backend" | ||
) | ||
|
||
var ( | ||
ollamaHost = "http://localhost:11434" | ||
ollamaGenModel = "qwen2.5" | ||
openaiModel = "gpt-4o-mini" | ||
) | ||
|
||
const ( | ||
systemMessage = `You are a helpful AI assistant that provides recommendations to users about software packages. | ||
Your job is to provide a recommendation based on the user's prompt. | ||
You might be provided a JSON summary along with the user's prompt. Do not summarize the JSON back to the user. | ||
Focus on whether the package is malicious or deprecated based on the provided tool input you get as JSON. | ||
Focus less on the number of stars or forks. | ||
If the user does not specify the ecosystem (one of npm, pypi, crates, maven, go), ask the user, NEVER assume the ecosystem. | ||
If the package is malicious or deprecated, recommend a safer alternative. | ||
If the package is safe, recommend the package. | ||
` | ||
summarizeMessage = `Summarize the tool response for me in plain speech. If the package is either malicious, deprecated | ||
or no longer maintained, recommend a bulleted list of two-three safer alternative packages that do the same. If the package is safe, recommend the package.` | ||
) | ||
|
||
func main() { | ||
var generationBackend backend.Backend | ||
|
||
beSelection := os.Getenv("BACKEND") | ||
if beSelection == "" { | ||
log.Println("No backend selected with the BACKEND env variable. Defaulting to Ollama.") | ||
beSelection = "ollama" | ||
} | ||
modelSelection := os.Getenv("MODEL") | ||
if modelSelection == "" { | ||
switch beSelection { | ||
case "ollama": | ||
modelSelection = ollamaGenModel | ||
case "openai": | ||
modelSelection = openaiModel | ||
} | ||
log.Println("No model selected with the MODEL env variable. Defaulting to ", modelSelection) | ||
} | ||
|
||
switch beSelection { | ||
case "ollama": | ||
generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel) | ||
log.Println("Using Ollama backend: ", ollamaGenModel) | ||
case "openai": | ||
openaiKey := os.Getenv("OPENAI_API_KEY") | ||
if openaiKey == "" { | ||
log.Fatalf("OPENAI_API_KEY is required for OpenAI backend") | ||
} | ||
generationBackend = backend.NewOpenAIBackend(openaiKey, openaiModel) | ||
log.Println("Using OpenAI backend: ", openaiModel) | ||
default: | ||
log.Fatalf("Unknown backend: %s", beSelection) | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) | ||
defer cancel() | ||
|
||
userPrompt := os.Args[1:] | ||
if len(userPrompt) == 0 { | ||
log.Fatalf("Please provide a prompt") | ||
} | ||
|
||
convo := backend.NewConversation() | ||
convo.Tools.RegisterTool(trusty.Tool()) | ||
// start the conversation. We add a system message to tune the output | ||
// and add the trusty tool to the conversation so that the model knows to call it. | ||
convo.AddSystemMessage(systemMessage, nil) | ||
convo.AddUserMessage(strings.Join(userPrompt, " "), nil) | ||
|
||
// generate the response | ||
resp, err := generationBackend.Converse(ctx, convo) | ||
if err != nil { | ||
log.Fatalf("Error generating response: %v", err) | ||
} | ||
|
||
if len(resp.ToolCalls) == 0 { | ||
log.Println("No tool calls in response.") | ||
log.Println("Response:", convo.Messages[len(convo.Messages)-1]["content"]) | ||
return | ||
} | ||
|
||
log.Println("Tool called") | ||
|
||
// if there was a tool response, first just feed it back to the model so it makes sense of it | ||
_, err = generationBackend.Converse(ctx, convo) | ||
if err != nil { | ||
log.Fatalf("Error generating response: %v", err) | ||
} | ||
|
||
log.Println("Summarizing tool response") | ||
|
||
// summarize the tool response | ||
convo.AddSystemMessage(summarizeMessage, nil) | ||
_, err = generationBackend.Converse(ctx, convo) | ||
if err != nil { | ||
log.Fatalf("Error generating response: %v", err) | ||
} | ||
|
||
log.Println("Response:") | ||
log.Println(convo.Messages[len(convo.Messages)-1]["content"]) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright 2024 Stacklok, Inc | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package trusty | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"github.com/stackloklabs/gollm/pkg/backend" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
// Tool returns a backend.Tool object that can be used to interact with the trusty tool. | ||
func Tool() backend.Tool { | ||
return backend.Tool{ | ||
Type: "function", | ||
Function: backend.ToolFunction{ | ||
Name: "trusty", | ||
Description: "Evaluate the trustworthiness of a package", | ||
Parameters: map[string]any{ | ||
"type": "object", | ||
"properties": map[string]any{ | ||
"package_name": map[string]any{ | ||
"type": "string", | ||
"description": "The name of the package", | ||
}, | ||
"ecosystem": map[string]any{ | ||
"type": "string", | ||
"description": "The ecosystem of the package", | ||
"enum": []string{"npm", "pypi", "crates", "maven", "go"}, | ||
"default": "pypi", | ||
}, | ||
}, | ||
"required": []string{"package_name", "ecosystem"}, | ||
}, | ||
Wrapper: trustyReportWrapper, | ||
}, | ||
} | ||
} | ||
|
||
func trustyReportWrapper(params map[string]any) (string, error) { | ||
packageName, ok := params["package_name"].(string) | ||
if !ok { | ||
return "", fmt.Errorf("package_name must be a string") | ||
} | ||
ecosystem, ok := params["ecosystem"].(string) | ||
if !ok { | ||
ecosystem = "PyPi" | ||
} | ||
return trustyReport(packageName, ecosystem) | ||
} | ||
|
||
func trustyReport(packageName string, ecosystem string) (string, error) { | ||
url := fmt.Sprintf("https://api.trustypkg.dev/v1/report?package_name=%s&package_type=%s", packageName, ecosystem) | ||
|
||
req, err := http.NewRequest("GET", url, nil) | ||
if err != nil { | ||
return "", err | ||
} | ||
req.Header.Set("accept", "application/json") | ||
|
||
// Perform the request | ||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return "", err | ||
} | ||
defer resp.Body.Close() | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
var prettyJSON map[string]interface{} | ||
err = json.Unmarshal(body, &prettyJSON) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
// Convert the JSON back to string | ||
jsonString, err := json.MarshalIndent(prettyJSON, "", " ") | ||
if err != nil { | ||
return "", err | ||
} | ||
return string(jsonString), nil | ||
} |