Skip to content

Commit

Permalink
feat: add ai-workflow plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
lixf311 committed Aug 19, 2024
1 parent 25b085c commit f182f03
Show file tree
Hide file tree
Showing 8 changed files with 1,060 additions and 0 deletions.
2 changes: 2 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM scratch
COPY main.wasm plugin.wasm
369 changes: 369 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/README.md

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-workflow

go 1.19

replace github.com/alibaba/higress/plugins/wasm-go => ../..

require (
github.com/alibaba/higress/plugins/wasm-go v0.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/tidwall/sjson v1.2.5
)

require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
)
23 changes: 23 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
267 changes: 267 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"errors"
"fmt"
. "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-workflow/workflow"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"net/http"
)

const (
maxDepth uint = 100
)

func main() {
wrapper.SetCtx(
"ai-workflow",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
)
}

func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {

workflows := make([]WorkFlow, 0)
tools := make(map[string]Tool)

dsl := json.Get("dsl")
if !dsl.Exists() {
return errors.New("dsl is empty")
}
//处理dsl.workflow
workFlows_ := dsl.Get("workflow")
if workFlows_.Exists() && workFlows_.IsArray() {
for _, w := range workFlows_.Array() {
task := Task{}
workflow := WorkFlow{}
workflow.Source = w.Get("source").String()
if workflow.Source == "" {
return errors.New("source is empty")
}
workflow.Target = w.Get("target").String()
if workflow.Target == "" {
return errors.New("target is empty")
}
workflow.Task = &task
//workflow.Context = make(map[string]string)
workflow.Input = w.Get("input").String()
workflow.Output = w.Get("output").String()
workflow.Conditional = w.Get("conditional").String()
workflows = append(workflows, workflow)
}
}
c.DSL.WorkFlow = workflows

//处理tools
tools_ := json.Get("tools")
if tools_.Exists() && tools_.IsArray() {

for _, value := range tools_.Array() {
tool := Tool{}
tool.Name = value.Get("name").String()
if tool.Name == "" {
return errors.New("tool name is empty")
}
tool.ServiceType = value.Get("service_type").String()
if tool.ServiceType == "" {
return errors.New("tool service type is empty")
}
tool.ServiceName = value.Get("service_name").String()
if tool.ServiceName == "" {
return errors.New("tool service name is empty")
}
tool.ServicePort = value.Get("service_port").Int()
if tool.ServicePort == 0 {
if tool.ServiceType == ToolServiceTypeStatic {
tool.ServicePort = 80
} else {
return errors.New("tool service port is empty")
}

}
tool.ServiceDomain = value.Get("service_domain").String()
tool.ServicePath = value.Get("service_path").String()
if tool.ServicePath == "" {
tool.ServicePath = "/"
}
tool.ServiceMethod = value.Get("service_method").String()
if tool.ServiceMethod == "" {
return errors.New("service_method is empty")
}
serviceHeaders := value.Get("service_headers")
if serviceHeaders.Exists() && serviceHeaders.IsArray() {
tool.ServiceHeaders = make([][2]string, 0)
for _, serviceHeader := range serviceHeaders.Array() {
if serviceHeader.IsArray() && len(serviceHeader.Array()) == 2 {
kv := serviceHeader.Array()
tool.ServiceHeaders = append(tool.ServiceHeaders, [2]string{kv[0].String(), kv[1].String()})
} else {
return errors.New("service_headers is not allow")
}

}
}
tool.ServiceBodyTmpl = value.Get("service_body_tmpl").String()
serviceBodyReplaceKeys := value.Get("service_body_replace_keys")
if serviceBodyReplaceKeys.Exists() && serviceBodyReplaceKeys.IsArray() {
tool.ServiceBodyReplaceKeys = make([][2]string, 0)
for _, serviceBodyReplaceKey := range serviceBodyReplaceKeys.Array() {
if serviceBodyReplaceKey.IsArray() && len(serviceBodyReplaceKey.Array()) == 2 {
keys := serviceBodyReplaceKey.Array()
tool.ServiceBodyReplaceKeys = append(tool.ServiceBodyReplaceKeys, [2]string{keys[0].String(), keys[1].String()})
} else {
return errors.New("service body replace keys is not allow")
}
}
}
tools[tool.Name] = tool
}
c.Tools = tools
}
log.Debugf("config : %v", c)
return nil
}

func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {

initHeader := make([][2]string, 0)
err := recursive(config.DSL.WorkFlow, initHeader, body, 1, maxDepth, config, log, ctx)
if err != nil {
log.Errorf("recursive failed: %v", err)
}
return types.ActionPause
}

func recursive(workflows []WorkFlow, headers [][2]string, body []byte, depth uint, maxDepth uint, config PluginConfig, log wrapper.Log, ctx wrapper.HttpContext) error {

var err error
// 防止递归次数太多
if depth > maxDepth {
return fmt.Errorf("maximum recursion depth reached")
}
step := depth - 1

log.Debugf("workflow is %v", workflows[step])
workflow := workflows[step]

// 执行判断Conditional
if workflow.Conditional != "" {
//填充Conditional
workflow.Conditional, err = workflow.WrapperDataByTmplStr(workflow.Conditional, body, ctx)
if err != nil {
log.Errorf("workflow WrapperDateByTmplStr %s failed: %v", workflow.Conditional, err)
return fmt.Errorf("workflow WrapperDateByTmplStr %s failed: %v", workflow.Conditional, err)
}
log.Debugf("Exec Conditional is %s", workflow.Conditional)
ok, err := workflow.ExecConditional()
if err != nil {
log.Errorf("wl exec conditional %s failed: %v", workflow.Conditional, err)
return fmt.Errorf("wl exec conditional %s failed: %v", workflow.Conditional, err)
}
//如果不通过直接跳过这步
if !ok {
log.Debugf("workflow is pass")
err = recursive(workflows, headers, body, depth+1, maxDepth, config, log, ctx)
if err != nil {

return err
}
return nil
}
}
//判断是不是end
if workflow.IsEnd() {
log.Debugf("workflow is end")
log.Debugf("body is %s", string(body))
proxywasm.SendHttpResponse(200, headers, body, -1)
return nil
}
//判断是不是continue
if workflow.IsContinue() {
log.Debugf("workflow is continue")
proxywasm.ResumeHttpRequest()
return nil
}

// 过滤input
if workflow.Input != "" {
inputJson := gjson.GetBytes(body, workflow.Input)
if inputJson.Exists() {
body = []byte(inputJson.Raw)
} else {
return fmt.Errorf("input filter get path %s is not found,json is %s", workflow.Input, string(body))
}
}
// 存入这轮请求的body
ctx.SetContext(fmt.Sprintf("%s-input", workflow.Target), body)
// 封装task
err = workflow.WrapperTask(config, ctx)
if err != nil {
log.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", workflow.Source, workflow.Target, err)
return fmt.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", workflow.Source, workflow.Target, err)
}

//执行task
log.Debugf("workflow exec task,source is %s,target is %s, body is %s,header is %v", workflow.Source, workflow.Target, string(workflow.Task.Body), workflow.Task.Headers)
err = wrapper.HttpCall(workflow.Task.Cluster, workflow.Task.Method, workflow.Task.ServicePath, workflow.Task.Headers, workflow.Task.Body, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("code:%d", statusCode)
//判断response code
if statusCode < 400 {
if workflow.Output != "" {
out := gjson.GetBytes(responseBody, workflow.Output)
if out.Exists() {
responseBody = []byte(out.Raw)
} else {
log.Errorf("workflow get path %s exec response body %s not found", workflow.Output, string(responseBody))
proxywasm.ResumeHttpRequest()
return
}
}
//存入 这轮返回的body
ctx.SetContext(fmt.Sprintf("%s-output", workflow.Target), responseBody)

headers_ := make([][2]string, len(responseHeaders))
for key, value := range responseHeaders {
headers_ = append(headers_, [2]string{key, value[0]})
}
//进入下一步
log.Debugf("workflow exec response body %s ", string(responseBody))
err = recursive(workflows, headers_, responseBody, depth+1, maxDepth, config, log, ctx)

if err != nil {
log.Errorf("recursive error:%v", err)
proxywasm.ResumeHttpRequest()
return
}
} else {
//statusCode >= 400 ,task httpCall执行失败,放行请求,打印错误,结束workflow
log.Errorf("workflow exec task find error,code is %d,body is %s", statusCode, string(responseBody))
proxywasm.ResumeHttpRequest()
}
return

}, uint32(maxDepth-step)*5000)
if err != nil {
log.Errorf("httpcall error:%v", err)
}

return err
}
76 changes: 76 additions & 0 deletions plugins/wasm-go/extensions/ai-workflow/utils/conditional.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package utils

import (
"fmt"
"regexp"
"strconv"
"strings"
)

/*
用来比较数据大小:
eq arg1 arg2: arg1 == arg2时为true
ne arg1 arg2: arg1 != arg2时为true
lt arg1 arg2: arg1 < arg2时为true
le arg1 arg2: arg1 <= arg2时为true
gt arg1 arg2: arg1 > arg2时为true
ge arg1 arg2: arg1 >= arg2时为true
*/
type CompareFunc func(a, b float64) bool

var operators = map[string]interface{}{
//CompareFunc 用来比较float64 数据大小:
"eq": func(a, b float64) bool { return a == b },
"ge": func(a, b float64) bool { return a >= b },
"le": func(a, b float64) bool { return a <= b },
"gt": func(a, b float64) bool { return a > b },
"lt": func(a, b float64) bool { return a < b },
//todo 添加别的判断函数
}

// 执行判断条件
func ExecConditionalStr(ConditionalStr string) (bool, error) {
fields := strings.Fields(ConditionalStr)
if len(fields) != 3 {
return false, fmt.Errorf("invalid conditional str %s,fields num is %d", ConditionalStr, len(fields))
}
compareFunc := operators[fields[0]]
switch fc := compareFunc.(type) {
default:
return false, fmt.Errorf("invalid conditional func %v", compareFunc)
case func(a, b float64) bool:
a, err := strconv.ParseFloat(fields[1], 64)
if err != nil {
return false, fmt.Errorf("invalid conditional str %s", ConditionalStr)
}
b, err := strconv.ParseFloat(fields[2], 64)
if err != nil {
return false, fmt.Errorf("invalid conditional str %s", ConditionalStr)
}
return fc(a, b), nil
}

}

// 通过正泽表达式寻找模板中的 {{foo}} 字符串foo
// 返回 {{foo}} : foo
func ParseTmplStr(tmpl string) map[string]string {
result := make(map[string]string)
re := regexp.MustCompile(`\{\{(.*?)\}\}`)
matches := re.FindAllStringSubmatch(tmpl, -1)
for _, match := range matches {
result[match[0]] = match[1]
}
return result
}

// 使用kv替换模板中的字符
// 例如 模板是`hello,{{foo}}` 使用{"{{foo}}":"bot"} 替换后为`hello,bot`
func ReplacedStr(tmpl string, kvs map[string]string) string {

for k, v := range kvs {
tmpl = strings.Replace(tmpl, k, v, -1)
}

return tmpl
}
Loading

0 comments on commit f182f03

Please sign in to comment.