Skip to content

Commit fc9d1e2

Browse files
Added vertex ai prompt attribute
1 parent 993742e commit fc9d1e2

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

client/vertexai/instance.go

-5
This file was deleted.

client/vertexai/predictrequest.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package vertexai
22

33
type predictRequest struct {
4-
Instances []instance `json:"instances"`
5-
Parameters parameters `json:"parameters"`
4+
Instances []map[string]interface{} `json:"instances"`
5+
Parameters parameters `json:"parameters"`
66
}

client/vertexai/vertexaiclient.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/spandigitial/codeassistant/client"
88
"github.com/spandigitial/codeassistant/client/debugger"
99
"github.com/spandigitial/codeassistant/model"
10+
"github.com/spf13/viper"
1011
"io"
1112
"net/http"
1213
"time"
@@ -74,14 +75,15 @@ func (c *Client) Completion(commandInstance *model.CommandInstance, messageParts
7475
TopK: topK,
7576
}
7677

78+
prompt := commandInstance.JoinedPromptsContent("\n")
7779
request := predictRequest{
78-
Instances: []instance{{
79-
Content: commandInstance.JoinedPromptsContent("\n"),
80+
Instances: []map[string]interface{}{{
81+
viper.GetString("vertexAiPromptAttribute"): prompt,
8082
}},
8183
Parameters: parameters,
8284
}
8385

84-
c.debugger.Message(debugger.SentPrompt, request.Instances[0].Content)
86+
c.debugger.Message(debugger.SentPrompt, prompt)
8587

8688
requestBytes, err := json.Marshal(request)
8789
if err != nil {

cmd/root.go

+4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ func init() {
9494
if err := viper.BindPFlag("vertexAiModel", rootCmd.PersistentFlags().Lookup("vertexAiModel")); err != nil {
9595
log.Fatal("Unable to find flag vertexAiModel", err)
9696
}
97+
rootCmd.PersistentFlags().String("vertexAiPromptAttribute", "content", "Model to use if not specified")
98+
if err := viper.BindPFlag("vertexAiPromptAttribute", rootCmd.PersistentFlags().Lookup("vertexAiPromptAttribute")); err != nil {
99+
log.Fatal("Unable to find flag vertexAiModel", err)
100+
}
97101
rootCmd.PersistentFlags().String("vertexAiLocation", "us-central1", "Locstion to use if not specified")
98102
if err := viper.BindPFlag("vertexAiLocation", rootCmd.PersistentFlags().Lookup("vertexAiLocation")); err != nil {
99103
log.Fatal("Unable to find flag vertexAiLocation", err)

0 commit comments

Comments
 (0)