Skip to content

Commit

Permalink
RSDK-9760 - Tunneling over grpc (viamrobotics#4745)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheukt authored Jan 28, 2025
1 parent 24c495b commit dccdac0
Show file tree
Hide file tree
Showing 9 changed files with 683 additions and 6 deletions.
8 changes: 8 additions & 0 deletions examples/tunnel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Tunnel
This example shows how to use the traffic tunneling feature in the viam-server


### Running
Run this example with `go run tunnel.go -addr {address of machine} -api-key {api key to use to connect to machine} -api-key-id {api key id to use to connect to machine} -dest {destination address to tunnel to (default 3389)} -src {source address to listen on (default 9090)}`

API key and API key id can be left blank if the machine is insecure.
122 changes: 122 additions & 0 deletions examples/tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// main is an example of tunneling traffic over grpc
package main

import (
"context"
"flag"
"net"
"strconv"
"sync"

"go.viam.com/utils/rpc"

"go.viam.com/rdk/logging"
"go.viam.com/rdk/robot/client"
)

var (
address = ""

defaultSourcePort = 9090
defaultDestinationPort = 3389
)

func main() {
logger := logging.NewDebugLogger("client")
var src int
flag.IntVar(&src, "src", defaultSourcePort, "source address to listen on")

var dest int
flag.IntVar(&dest, "dest", defaultDestinationPort, "destination address to tunnel to")

var addr string
flag.StringVar(&addr, "addr", address, "machine name to connect to")

var apiKey string
flag.StringVar(&apiKey, "api-key", apiKey, "api key to use to connect to machine")

var apiKeyID string
flag.StringVar(&apiKeyID, "api-key-id", apiKeyID, "api key id to use to connect to machine")

flag.Parse()

if addr == "" {
logger.Error("please enter an address with flag --addr")
return
}

logger.Infow("starting tunnel", "source address", src, "destination address", dest)
ctx := context.Background()

opts := []client.RobotClientOption{
client.WithRefreshEvery(0),
client.WithCheckConnectedEvery(0),
client.WithDisableSessions(),
}

if apiKey != "" && apiKeyID != "" {
opts = append(opts,
client.WithDialOptions(rpc.WithEntityCredentials(
apiKeyID,
rpc.Credentials{
Type: rpc.CredentialsTypeAPIKey,
Payload: apiKey,
}),
),
)
} else {
opts = append(opts,
client.WithDialOptions(
rpc.WithInsecure(),
rpc.WithDisableDirectGRPC(),
),
)
}
machine, err := client.New(ctx, addr, logger, opts...)
if err != nil {
logger.Error(err)
return
}

defer func() {
if err := machine.Close(context.Background()); err != nil {
logger.CErrorw(ctx, "error closing machine", "err", err)
}
}()
tunnelTraffic(ctx, machine, src, dest, logger)
}

func tunnelTraffic(ctx context.Context, machine *client.RobotClient, src, dest int, logger logging.Logger) {
// create listener
li, err := net.Listen("tcp", net.JoinHostPort("localhost", strconv.Itoa(src)))
if err != nil {
logger.CErrorw(ctx, "failed to create listener", "err", err)
return
}
defer func() {
if err := li.Close(); err != nil {
logger.CErrorw(ctx, "error closing listener", "err", err)
}
}()

var wg sync.WaitGroup
for {
if ctx.Err() != nil {
break
}
conn, err := li.Accept()
if err != nil {
logger.CErrorw(ctx, "failed to accept conn: %v\n", err)
}
wg.Add(1)
go func() {
defer wg.Done()
// call tunnel once per connection, the connection passed in will be closed
// by Tunnel.
if err := machine.Tunnel(ctx, conn, dest); err != nil {
logger.CError(ctx, err)
}
}()
}
wg.Wait()
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ require (
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
go.viam.com/api v0.1.383
go.viam.com/api v0.1.387
go.viam.com/test v1.2.4
go.viam.com/utils v0.1.128
goji.io v2.0.2+incompatible
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1513,8 +1513,8 @@ go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.viam.com/api v0.1.383 h1:HE9EpUWoyDBYJLTVLo29f66oCjzH31V1YJ02tlzCtyo=
go.viam.com/api v0.1.383/go.mod h1:g5eipXHNm0rQmW7DWya6avKcmzoypLmxnMlAaIsE5Ls=
go.viam.com/api v0.1.387 h1:nxWF+dO+z2mwTpTkcs8mr8L0ic4ZwcqOAumBtHnvqsc=
go.viam.com/api v0.1.387/go.mod h1:g5eipXHNm0rQmW7DWya6avKcmzoypLmxnMlAaIsE5Ls=
go.viam.com/test v1.2.4 h1:JYgZhsuGAQ8sL9jWkziAXN9VJJiKbjoi9BsO33TW3ug=
go.viam.com/test v1.2.4/go.mod h1:zI2xzosHdqXAJ/kFqcN+OIF78kQuTV2nIhGZ8EzvaJI=
go.viam.com/utils v0.1.128 h1:ScO0pWiwoYzsILcJudL4axCAGhNPmzszJeMoHhdx23s=
Expand Down
78 changes: 78 additions & 0 deletions robot/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"go.viam.com/rdk/robot/packages"
"go.viam.com/rdk/session"
"go.viam.com/rdk/spatialmath"
"go.viam.com/rdk/tunnel"
"go.viam.com/rdk/utils/contextutils"
)

Expand Down Expand Up @@ -1216,6 +1217,83 @@ func (rc *RobotClient) Version(ctx context.Context) (robot.VersionResponse, erro
return mVersion, nil
}

// Tunnel tunnels data to/from the read writer from/to the destination port on the server. This
// function will close the connection passed in as part of cleanup.
func (rc *RobotClient) Tunnel(ctx context.Context, conn io.ReadWriteCloser, dest int) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
client, err := rc.client.Tunnel(ctx)
if err != nil {
return err
}

if err := client.Send(&pb.TunnelRequest{
DestinationPort: uint32(dest),
}); err != nil {
return err
}
rc.Logger().CInfow(ctx, "creating tunnel to server", "port", dest)
var (
wg sync.WaitGroup
readerSenderErr error

timerMu sync.Mutex
timer *time.Timer
)
connClosed := make(chan struct{})
rsDone := make(chan struct{})
wg.Add(1)
utils.PanicCapturingGo(func() {
defer func() {
// We communicate an end to the stream by calling CloseSend().
// Close the channel first so that network errors can be filtered
// and prevented in the RecvWriterLoop.
close(rsDone)
readerSenderErr = errors.Join(readerSenderErr, client.CloseSend())

// Schedule a task to cancel the context if we do not exit out of the recvWriterLoop within 5 seconds.
// This will close the client, meaning client.Recv() in the RecvWriterLoop will exit and return an error.
//
// NOTE(cheukt): This may cause DEBUG messages from WebRTC stating `no stream for id; discarding`
// to show up because the handler will have exited before we receive the last messages from the server.
// This is not an issue and is expected.
timerMu.Lock()
timer = time.AfterFunc(5*time.Second, cancel)
timerMu.Unlock()
wg.Done()
}()
// a max of 32kb will be sent per message (based on io.Copy's default buffer size)
sendFunc := func(data []byte) error { return client.Send(&pb.TunnelRequest{Data: data}) }
readerSenderErr = tunnel.ReaderSenderLoop(ctx, conn, sendFunc, connClosed, rc.logger.WithFields("loop", "reader/sender"))
})

recvFunc := func() ([]byte, error) {
resp, err := client.Recv()
if err != nil {
return nil, err
}
return resp.Data, nil
}
recvWriterErr := tunnel.RecvWriterLoop(ctx, recvFunc, conn, rsDone, rc.logger.WithFields("loop", "recv/writer"))
timerMu.Lock()
// cancel the timer if we've successfully returned from the RecvWriterLoop
if timer != nil {
timer.Stop()
}
timerMu.Unlock()

// We close the connection to unblock the reader/sender loop, which is not clean
// but there isn't a cleaner way to exit from the reader/sender loop.
// Close the channel first so that network errors can be filtered
// and prevented in the ReaderSenderLoop.
close(connClosed)
err = conn.Close()

wg.Wait()
rc.Logger().CInfow(ctx, "tunnel to server closed", "port", dest)
return errors.Join(err, readerSenderErr, recvWriterErr)
}

func unaryClientInterceptor() googlegrpc.UnaryClientInterceptor {
return func(
ctx context.Context,
Expand Down
61 changes: 58 additions & 3 deletions robot/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ package server
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"

"github.com/google/uuid"
"github.com/pkg/errors"
"go.uber.org/zap/zapcore"
commonpb "go.viam.com/api/common/v1"
pb "go.viam.com/api/robot/v1"
Expand All @@ -30,6 +33,7 @@ import (
"go.viam.com/rdk/resource"
"go.viam.com/rdk/robot"
"go.viam.com/rdk/session"
"go.viam.com/rdk/tunnel"
)

// logTSKey is the key used in conjunction with the timestamp of logs received
Expand All @@ -56,6 +60,57 @@ func New(robot robot.Robot) pb.RobotServiceServer {
func (s *Server) Close() {
}

// Tunnel tunnels traffic to/from the client from/to a specified port on the server.
func (s *Server) Tunnel(srv pb.RobotService_TunnelServer) error {
req, err := srv.Recv()
if err != nil {
return fmt.Errorf("failed to receive first message from stream: %w", err)
}

dest := strconv.Itoa(int(req.DestinationPort))
s.robot.Logger().CDebugw(srv.Context(), "dialing to destination port", "port", dest)

dialTimeout := 10 * time.Second
conn, err := net.DialTimeout("tcp", net.JoinHostPort("localhost", dest), dialTimeout)
if err != nil {
return fmt.Errorf("failed to dial to destination port %v: %w", dest, err)
}
s.robot.Logger().CInfow(srv.Context(), "successfully dialed to destination port, creating tunnel", "port", dest)

var (
wg sync.WaitGroup
readerSenderErr error
)
connClosed := make(chan struct{})
rsDone := make(chan struct{})
wg.Add(1)
utils.PanicCapturingGo(func() {
defer func() {
close(rsDone)
wg.Done()
}()
// a max of 32kb will be sent per message (based on io.Copy's default buffer size)
sendFunc := func(data []byte) error { return srv.Send(&pb.TunnelResponse{Data: data}) }
readerSenderErr = tunnel.ReaderSenderLoop(srv.Context(), conn, sendFunc, connClosed, s.robot.Logger().WithFields("loop", "reader/sender"))
})
recvFunc := func() ([]byte, error) {
req, err := srv.Recv()
if err != nil {
return nil, err
}
return req.Data, nil
}
recvWriterErr := tunnel.RecvWriterLoop(srv.Context(), recvFunc, conn, rsDone, s.robot.Logger().WithFields("loop", "recv/writer"))
// close the connection to unblock the read
// close the channel first so that network errors can be filtered
// and prevented in the ReaderSenderLoop.
close(connClosed)
err = conn.Close()
wg.Wait()
s.robot.Logger().CInfow(srv.Context(), "tunnel to client closed", "port", dest)
return errors.Join(err, readerSenderErr, recvWriterErr)
}

// GetOperations lists all running operations.
func (s *Server) GetOperations(ctx context.Context, req *pb.GetOperationsRequest) (*pb.GetOperationsResponse, error) {
me := operation.Get(ctx)
Expand Down Expand Up @@ -201,7 +256,7 @@ func (s *Server) DiscoverComponents(ctx context.Context, req *pb.DiscoverCompone
for _, discovery := range discoveries {
pbResults, err := vprotoutils.StructToStructPb(discovery.Results)
if err != nil {
return nil, errors.Wrapf(err, "unable to construct a structpb.Struct from discovery for %q", discovery.Query)
return nil, fmt.Errorf("unable to construct a structpb.Struct from discovery for %q: %w", discovery.Query, err)
}
extra, err := structpb.NewStruct(discovery.Query.Extra)
if err != nil {
Expand Down Expand Up @@ -399,7 +454,7 @@ func (s *Server) Log(ctx context.Context, req *pb.LogRequest) (*pb.LogResponse,
for _, fieldP := range log.Fields {
field, err := logging.FieldFromProto(fieldP)
if err != nil {
return nil, errors.Wrap(err, "error converting LogRequest log field from proto")
return nil, fmt.Errorf("error converting LogRequest log field from proto: %w", err)
}
fields = append(fields, field)
}
Expand Down
Loading

0 comments on commit dccdac0

Please sign in to comment.