-
Notifications
You must be signed in to change notification settings - Fork 509
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
lixf311
committed
Aug 19, 2024
1 parent
25b085c
commit f182f03
Showing
8 changed files
with
1,060 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
FROM scratch | ||
COPY main.wasm plugin.wasm |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-workflow | ||
|
||
go 1.19 | ||
|
||
replace github.com/alibaba/higress/plugins/wasm-go => ../.. | ||
|
||
require ( | ||
github.com/alibaba/higress/plugins/wasm-go v0.0.0 | ||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f | ||
github.com/tidwall/gjson v1.14.3 | ||
github.com/tidwall/sjson v1.2.5 | ||
) | ||
|
||
require ( | ||
github.com/google/uuid v1.3.0 // indirect | ||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect | ||
github.com/magefile/mage v1.14.0 // indirect | ||
github.com/tidwall/match v1.1.1 // indirect | ||
github.com/tidwall/pretty v1.2.0 // indirect | ||
github.com/tidwall/resp v0.1.1 // indirect | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | ||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= | ||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= | ||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= | ||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= | ||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= | ||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= | ||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= | ||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= | ||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= | ||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= | ||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= | ||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= | ||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= | ||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= | ||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= | ||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= | ||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= | ||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= | ||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
// Copyright (c) 2022 Alibaba Group Holding Ltd. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package main | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
. "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-workflow/workflow" | ||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" | ||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" | ||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" | ||
"github.com/tidwall/gjson" | ||
"net/http" | ||
) | ||
|
||
const ( | ||
maxDepth uint = 100 | ||
) | ||
|
||
func main() { | ||
wrapper.SetCtx( | ||
"ai-workflow", | ||
wrapper.ParseConfigBy(parseConfig), | ||
wrapper.ProcessRequestBodyBy(onHttpRequestBody), | ||
) | ||
} | ||
|
||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { | ||
|
||
workflows := make([]WorkFlow, 0) | ||
tools := make(map[string]Tool) | ||
|
||
dsl := json.Get("dsl") | ||
if !dsl.Exists() { | ||
return errors.New("dsl is empty") | ||
} | ||
//处理dsl.workflow | ||
workFlows_ := dsl.Get("workflow") | ||
if workFlows_.Exists() && workFlows_.IsArray() { | ||
for _, w := range workFlows_.Array() { | ||
task := Task{} | ||
workflow := WorkFlow{} | ||
workflow.Source = w.Get("source").String() | ||
if workflow.Source == "" { | ||
return errors.New("source is empty") | ||
} | ||
workflow.Target = w.Get("target").String() | ||
if workflow.Target == "" { | ||
return errors.New("target is empty") | ||
} | ||
workflow.Task = &task | ||
//workflow.Context = make(map[string]string) | ||
workflow.Input = w.Get("input").String() | ||
workflow.Output = w.Get("output").String() | ||
workflow.Conditional = w.Get("conditional").String() | ||
workflows = append(workflows, workflow) | ||
} | ||
} | ||
c.DSL.WorkFlow = workflows | ||
|
||
//处理tools | ||
tools_ := json.Get("tools") | ||
if tools_.Exists() && tools_.IsArray() { | ||
|
||
for _, value := range tools_.Array() { | ||
tool := Tool{} | ||
tool.Name = value.Get("name").String() | ||
if tool.Name == "" { | ||
return errors.New("tool name is empty") | ||
} | ||
tool.ServiceType = value.Get("service_type").String() | ||
if tool.ServiceType == "" { | ||
return errors.New("tool service type is empty") | ||
} | ||
tool.ServiceName = value.Get("service_name").String() | ||
if tool.ServiceName == "" { | ||
return errors.New("tool service name is empty") | ||
} | ||
tool.ServicePort = value.Get("service_port").Int() | ||
if tool.ServicePort == 0 { | ||
if tool.ServiceType == ToolServiceTypeStatic { | ||
tool.ServicePort = 80 | ||
} else { | ||
return errors.New("tool service port is empty") | ||
} | ||
|
||
} | ||
tool.ServiceDomain = value.Get("service_domain").String() | ||
tool.ServicePath = value.Get("service_path").String() | ||
if tool.ServicePath == "" { | ||
tool.ServicePath = "/" | ||
} | ||
tool.ServiceMethod = value.Get("service_method").String() | ||
if tool.ServiceMethod == "" { | ||
return errors.New("service_method is empty") | ||
} | ||
serviceHeaders := value.Get("service_headers") | ||
if serviceHeaders.Exists() && serviceHeaders.IsArray() { | ||
tool.ServiceHeaders = make([][2]string, 0) | ||
for _, serviceHeader := range serviceHeaders.Array() { | ||
if serviceHeader.IsArray() && len(serviceHeader.Array()) == 2 { | ||
kv := serviceHeader.Array() | ||
tool.ServiceHeaders = append(tool.ServiceHeaders, [2]string{kv[0].String(), kv[1].String()}) | ||
} else { | ||
return errors.New("service_headers is not allow") | ||
} | ||
|
||
} | ||
} | ||
tool.ServiceBodyTmpl = value.Get("service_body_tmpl").String() | ||
serviceBodyReplaceKeys := value.Get("service_body_replace_keys") | ||
if serviceBodyReplaceKeys.Exists() && serviceBodyReplaceKeys.IsArray() { | ||
tool.ServiceBodyReplaceKeys = make([][2]string, 0) | ||
for _, serviceBodyReplaceKey := range serviceBodyReplaceKeys.Array() { | ||
if serviceBodyReplaceKey.IsArray() && len(serviceBodyReplaceKey.Array()) == 2 { | ||
keys := serviceBodyReplaceKey.Array() | ||
tool.ServiceBodyReplaceKeys = append(tool.ServiceBodyReplaceKeys, [2]string{keys[0].String(), keys[1].String()}) | ||
} else { | ||
return errors.New("service body replace keys is not allow") | ||
} | ||
} | ||
} | ||
tools[tool.Name] = tool | ||
} | ||
c.Tools = tools | ||
} | ||
log.Debugf("config : %v", c) | ||
return nil | ||
} | ||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { | ||
|
||
initHeader := make([][2]string, 0) | ||
err := recursive(config.DSL.WorkFlow, initHeader, body, 1, maxDepth, config, log, ctx) | ||
if err != nil { | ||
log.Errorf("recursive failed: %v", err) | ||
} | ||
return types.ActionPause | ||
} | ||
|
||
func recursive(workflows []WorkFlow, headers [][2]string, body []byte, depth uint, maxDepth uint, config PluginConfig, log wrapper.Log, ctx wrapper.HttpContext) error { | ||
|
||
var err error | ||
// 防止递归次数太多 | ||
if depth > maxDepth { | ||
return fmt.Errorf("maximum recursion depth reached") | ||
} | ||
step := depth - 1 | ||
|
||
log.Debugf("workflow is %v", workflows[step]) | ||
workflow := workflows[step] | ||
|
||
// 执行判断Conditional | ||
if workflow.Conditional != "" { | ||
//填充Conditional | ||
workflow.Conditional, err = workflow.WrapperDataByTmplStr(workflow.Conditional, body, ctx) | ||
if err != nil { | ||
log.Errorf("workflow WrapperDateByTmplStr %s failed: %v", workflow.Conditional, err) | ||
return fmt.Errorf("workflow WrapperDateByTmplStr %s failed: %v", workflow.Conditional, err) | ||
} | ||
log.Debugf("Exec Conditional is %s", workflow.Conditional) | ||
ok, err := workflow.ExecConditional() | ||
if err != nil { | ||
log.Errorf("wl exec conditional %s failed: %v", workflow.Conditional, err) | ||
return fmt.Errorf("wl exec conditional %s failed: %v", workflow.Conditional, err) | ||
} | ||
//如果不通过直接跳过这步 | ||
if !ok { | ||
log.Debugf("workflow is pass") | ||
err = recursive(workflows, headers, body, depth+1, maxDepth, config, log, ctx) | ||
if err != nil { | ||
|
||
return err | ||
} | ||
return nil | ||
} | ||
} | ||
//判断是不是end | ||
if workflow.IsEnd() { | ||
log.Debugf("workflow is end") | ||
log.Debugf("body is %s", string(body)) | ||
proxywasm.SendHttpResponse(200, headers, body, -1) | ||
return nil | ||
} | ||
//判断是不是continue | ||
if workflow.IsContinue() { | ||
log.Debugf("workflow is continue") | ||
proxywasm.ResumeHttpRequest() | ||
return nil | ||
} | ||
|
||
// 过滤input | ||
if workflow.Input != "" { | ||
inputJson := gjson.GetBytes(body, workflow.Input) | ||
if inputJson.Exists() { | ||
body = []byte(inputJson.Raw) | ||
} else { | ||
return fmt.Errorf("input filter get path %s is not found,json is %s", workflow.Input, string(body)) | ||
} | ||
} | ||
// 存入这轮请求的body | ||
ctx.SetContext(fmt.Sprintf("%s-input", workflow.Target), body) | ||
// 封装task | ||
err = workflow.WrapperTask(config, ctx) | ||
if err != nil { | ||
log.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", workflow.Source, workflow.Target, err) | ||
return fmt.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", workflow.Source, workflow.Target, err) | ||
} | ||
|
||
//执行task | ||
log.Debugf("workflow exec task,source is %s,target is %s, body is %s,header is %v", workflow.Source, workflow.Target, string(workflow.Task.Body), workflow.Task.Headers) | ||
err = wrapper.HttpCall(workflow.Task.Cluster, workflow.Task.Method, workflow.Task.ServicePath, workflow.Task.Headers, workflow.Task.Body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { | ||
log.Debugf("code:%d", statusCode) | ||
//判断response code | ||
if statusCode < 400 { | ||
if workflow.Output != "" { | ||
out := gjson.GetBytes(responseBody, workflow.Output) | ||
if out.Exists() { | ||
responseBody = []byte(out.Raw) | ||
} else { | ||
log.Errorf("workflow get path %s exec response body %s not found", workflow.Output, string(responseBody)) | ||
proxywasm.ResumeHttpRequest() | ||
return | ||
} | ||
} | ||
//存入 这轮返回的body | ||
ctx.SetContext(fmt.Sprintf("%s-output", workflow.Target), responseBody) | ||
|
||
headers_ := make([][2]string, len(responseHeaders)) | ||
for key, value := range responseHeaders { | ||
headers_ = append(headers_, [2]string{key, value[0]}) | ||
} | ||
//进入下一步 | ||
log.Debugf("workflow exec response body %s ", string(responseBody)) | ||
err = recursive(workflows, headers_, responseBody, depth+1, maxDepth, config, log, ctx) | ||
|
||
if err != nil { | ||
log.Errorf("recursive error:%v", err) | ||
proxywasm.ResumeHttpRequest() | ||
return | ||
} | ||
} else { | ||
//statusCode >= 400 ,task httpCall执行失败,放行请求,打印错误,结束workflow | ||
log.Errorf("workflow exec task find error,code is %d,body is %s", statusCode, string(responseBody)) | ||
proxywasm.ResumeHttpRequest() | ||
} | ||
return | ||
|
||
}, uint32(maxDepth-step)*5000) | ||
if err != nil { | ||
log.Errorf("httpcall error:%v", err) | ||
} | ||
|
||
return err | ||
} |
76 changes: 76 additions & 0 deletions
76
plugins/wasm-go/extensions/ai-workflow/utils/conditional.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package utils | ||
|
||
import ( | ||
"fmt" | ||
"regexp" | ||
"strconv" | ||
"strings" | ||
) | ||
|
||
/* | ||
用来比较数据大小: | ||
eq arg1 arg2: arg1 == arg2时为true | ||
ne arg1 arg2: arg1 != arg2时为true | ||
lt arg1 arg2: arg1 < arg2时为true | ||
le arg1 arg2: arg1 <= arg2时为true | ||
gt arg1 arg2: arg1 > arg2时为true | ||
ge arg1 arg2: arg1 >= arg2时为true | ||
*/ | ||
type CompareFunc func(a, b float64) bool | ||
|
||
var operators = map[string]interface{}{ | ||
//CompareFunc 用来比较float64 数据大小: | ||
"eq": func(a, b float64) bool { return a == b }, | ||
"ge": func(a, b float64) bool { return a >= b }, | ||
"le": func(a, b float64) bool { return a <= b }, | ||
"gt": func(a, b float64) bool { return a > b }, | ||
"lt": func(a, b float64) bool { return a < b }, | ||
//todo 添加别的判断函数 | ||
} | ||
|
||
// 执行判断条件 | ||
func ExecConditionalStr(ConditionalStr string) (bool, error) { | ||
fields := strings.Fields(ConditionalStr) | ||
if len(fields) != 3 { | ||
return false, fmt.Errorf("invalid conditional str %s,fields num is %d", ConditionalStr, len(fields)) | ||
} | ||
compareFunc := operators[fields[0]] | ||
switch fc := compareFunc.(type) { | ||
default: | ||
return false, fmt.Errorf("invalid conditional func %v", compareFunc) | ||
case func(a, b float64) bool: | ||
a, err := strconv.ParseFloat(fields[1], 64) | ||
if err != nil { | ||
return false, fmt.Errorf("invalid conditional str %s", ConditionalStr) | ||
} | ||
b, err := strconv.ParseFloat(fields[2], 64) | ||
if err != nil { | ||
return false, fmt.Errorf("invalid conditional str %s", ConditionalStr) | ||
} | ||
return fc(a, b), nil | ||
} | ||
|
||
} | ||
|
||
// 通过正泽表达式寻找模板中的 {{foo}} 字符串foo | ||
// 返回 {{foo}} : foo | ||
func ParseTmplStr(tmpl string) map[string]string { | ||
result := make(map[string]string) | ||
re := regexp.MustCompile(`\{\{(.*?)\}\}`) | ||
matches := re.FindAllStringSubmatch(tmpl, -1) | ||
for _, match := range matches { | ||
result[match[0]] = match[1] | ||
} | ||
return result | ||
} | ||
|
||
// 使用kv替换模板中的字符 | ||
// 例如 模板是`hello,{{foo}}` 使用{"{{foo}}":"bot"} 替换后为`hello,bot` | ||
func ReplacedStr(tmpl string, kvs map[string]string) string { | ||
|
||
for k, v := range kvs { | ||
tmpl = strings.Replace(tmpl, k, v, -1) | ||
} | ||
|
||
return tmpl | ||
} |
Oops, something went wrong.