Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support manage connection #2867

Merged
merged 24 commits into from
May 27, 2024
Merged
175 changes: 175 additions & 0 deletions internal/io/connection/connection_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connection

import (
"encoding/json"
"fmt"
"strings"
"sync"

"github.com/lf-edge/ekuiper/contract/v2/api"
"github.com/lf-edge/ekuiper/v2/internal/io/mqtt/client"
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
"github.com/lf-edge/ekuiper/v2/internal/pkg/store"
"github.com/lf-edge/ekuiper/v2/internal/topo/context"
"github.com/lf-edge/ekuiper/v2/pkg/kv"
)

type Connection interface {
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
Ping() error
Close()
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
Attach()
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
DetachSub(props map[string]any)
DetachPub(props map[string]any)
Ref() int
}

func GetAllConnectionsID() []string {
globalConnectionManager.RLock()
defer globalConnectionManager.RUnlock()
ids := make([]string, 0)
for key := range globalConnectionManager.connectionPool {
ids = append(ids, key)
}
return ids
}

func PingConnection(id string) error {
conn, err := GetNameConnection(id)
if err != nil {
return err
}
return conn.Ping()
}

func GetNameConnection(selId string) (Connection, error) {
if selId == "" {
return nil, fmt.Errorf("connection id should be defined")
}
globalConnectionManager.RLock()
defer globalConnectionManager.RUnlock()
meta, ok := globalConnectionManager.connectionPool[selId]
if !ok {
return nil, fmt.Errorf("connection %s not existed", selId)
}
return meta.conn, nil
}

func CreateNamedConnection(ctx api.StreamContext, id, typ string, props map[string]any) error {
if id == "" || typ == "" {
return fmt.Errorf("connection id and type should be defined")
}
globalConnectionManager.Lock()
defer globalConnectionManager.Unlock()
_, ok := globalConnectionManager.connectionPool[id]
if ok {
return fmt.Errorf("connection %v already been created", id)
}
meta := ConnectionMeta{
ID: id,
Typ: typ,
Props: props,
}
b, err := json.Marshal(meta)
if err != nil {
return err
}
if err := globalConnectionManager.store.Set(id, string(b)); err != nil {
return err
}
conn, err := createNamedConnection(ctx, meta)
if err != nil {
return err
}
meta.conn = conn
globalConnectionManager.connectionPool[id] = meta
return nil
}

func createNamedConnection(ctx api.StreamContext, meta ConnectionMeta) (Connection, error) {
var conn Connection
var err error
switch strings.ToLower(meta.Typ) {
case "mqtt":
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Will named connection support read configuration by confKey?
conn, err = client.CreateClient(ctx, meta.ID, meta.Props)
default:
err = fmt.Errorf("unknown connection type")
}
if err != nil {
return nil, err
}
return conn, nil
}

func DropNameConnection(selId string) error {
if selId == "" {
return fmt.Errorf("connection id should be defined")
}
globalConnectionManager.Lock()
defer globalConnectionManager.Unlock()
meta, ok := globalConnectionManager.connectionPool[selId]
if !ok {
return nil
}
conn := meta.conn
if conn.Ref() > 0 {
return fmt.Errorf("connection %s can't be dropped due to reference", selId)
}
err := globalConnectionManager.store.Delete(selId)
if err != nil {
return fmt.Errorf("drop connection %s failed, err:%v", selId, err)
}
conn.Close()
delete(globalConnectionManager.connectionPool, selId)
return nil
}

var globalConnectionManager *ConnectionManager

func InitConnectionManager() error {
globalConnectionManager = &ConnectionManager{
connectionPool: make(map[string]ConnectionMeta),
}
globalConnectionManager.store, _ = store.GetKV("connectionMeta")
kvs, _ := globalConnectionManager.store.All()
for connectionID, raw := range kvs {
meta := ConnectionMeta{}
err := json.Unmarshal([]byte(raw), &meta)
if err != nil {
return fmt.Errorf("initialize connection:%v failed, err:%v", connectionID, err)
}
conn, err := createNamedConnection(context.Background(), meta)
if err != nil {
return fmt.Errorf("initialize connection:%v failed, err:%v", connectionID, err)
}
meta.conn = conn
globalConnectionManager.connectionPool[connectionID] = meta
}
return nil
}

type ConnectionManager struct {
sync.RWMutex
store kv.KeyValue
ngjaying marked this conversation as resolved.
Show resolved Hide resolved
connectionPool map[string]ConnectionMeta
}

type ConnectionMeta struct {
ID string `json:"id"`
Typ string `json:"typ"`
Props map[string]any `json:"props"`
conn Connection `json:"-"`
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package mqtt
package client

import (
"crypto/tls"
Expand All @@ -25,7 +25,6 @@ import (
"github.com/google/uuid"

"github.com/lf-edge/ekuiper/contract/v2/api"
"github.com/lf-edge/ekuiper/v2/internal/conf"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
"github.com/lf-edge/ekuiper/v2/pkg/cert"
"github.com/lf-edge/ekuiper/v2/pkg/errorx"
Expand Down Expand Up @@ -90,27 +89,27 @@ func (conn *Connection) onReconnecting(_ pahoMqtt.Client, _ *pahoMqtt.ClientOpti
}

// Do not call this directly. Call connection pool Attach method to get the connection
func (conn *Connection) attach() {
func (conn *Connection) Attach() {
conn.refCount.Add(1)
}

func (conn *Connection) Ref() int {
return int(conn.refCount.Load())
}

// Do not call this directly. Call connection pool Detach method to release the connection
func (conn *Connection) detachSub(topic string) bool {
func (conn *Connection) DetachSub(props map[string]any) {
conn.refCount.Add(-1)
topic, err := getTopicFromProps(props)
if err != nil {
return
}
delete(conn.subscriptions, topic)
conn.Client.Unsubscribe(topic)
if conn.refCount.Add(-1) == 0 {
go conn.Close()
return true
}
return false
}

func (conn *Connection) detachPub() bool {
if conn.refCount.Add(-1) == 0 {
go conn.Close()
return true
}
return false
func (conn *Connection) DetachPub(props map[string]any) {
conn.refCount.Add(-1)
}

func (conn *Connection) Subscribe(topic string, info *SubscriptionInfo) error {
Expand All @@ -133,20 +132,7 @@ func (conn *Connection) Ping() error {

// CreateClient creates a new mqtt client. It is anonymous and does not require a name.
func CreateClient(ctx api.StreamContext, selId string, props map[string]any) (*Connection, error) {
if selId != "" {
selectCfg := &conf.ConSelector{
ConnSelectorStr: selId,
}
if err := selectCfg.Init(); err != nil {
return nil, err
}
cf, err := selectCfg.ReadCfgFromYaml()
if err != nil {
return nil, err
}
props = cf
}
c, err := validateConfig(props)
c, err := ValidateConfig(props)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,7 +166,9 @@ func CreateClient(ctx api.StreamContext, selId string, props map[string]any) (*C
}
ctx.GetLogger().Infof("new mqtt client created")
con.Client = cli
con.attach()
if len(selId) > 0 {
con.Attach()
}
return con, nil
}

Expand All @@ -193,7 +181,7 @@ func handleToken(token pahoMqtt.Token) error {
return nil
}

func validateConfig(props map[string]any) (*ConnectionConfig, error) {
func ValidateConfig(props map[string]any) (*ConnectionConfig, error) {
c := &ConnectionConfig{PVersion: "3.1.1"}
err := cast.MapToStruct(props, c)
if err != nil {
Expand Down Expand Up @@ -227,3 +215,28 @@ func validateConfig(props map[string]any) (*ConnectionConfig, error) {
func (conn *Connection) GetClientId() string {
return conn.selId
}

Yisaer marked this conversation as resolved.
Show resolved Hide resolved
func CreateAnonymousConnection(ctx api.StreamContext, props map[string]any) (*Connection, error) {
cli, err := CreateClient(ctx, "", props)
if err != nil {
return nil, err
}
return cli, nil
}

const (
dataSourceProp = "datasource"
Yisaer marked this conversation as resolved.
Show resolved Hide resolved
topicProp = "topic"
)

func getTopicFromProps(props map[string]any) (string, error) {
v, ok := props[topicProp]
if ok {
return v.(string), nil
}
v, ok = props[dataSourceProp]
if ok {
return v.(string), nil
}
return "", fmt.Errorf("topic or datasource not defined")
}
55 changes: 55 additions & 0 deletions internal/io/mqtt/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"testing"

"github.com/stretchr/testify/assert"
)

// NOTICE!!! Need to run a MQTT broker in localhost:1883 for this test or change the url to your broker
const url = "tcp://127.0.0.1:1883"

func TestValidate(t *testing.T) {
tests := []struct {
name string
props map[string]any
err string
}{
{
name: "No server",
props: map[string]any{
"server": "",
},
err: "missing server property",
},
{
name: "invalid protocol",
props: map[string]any{
"server": url,
"protocolVersion": "5.0",
},
err: "unsupported protocol version 5.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ValidateConfig(tt.props)
assert.Error(t, err)
assert.Equal(t, tt.err, err.Error())
})
}
}
Loading
Loading