Skip to content

Commit

Permalink
Update GrpcInflightMethodLimiter to support interceptors
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalterman committed Jan 31, 2025
1 parent 50cd2e2 commit ac300a1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
31 changes: 31 additions & 0 deletions server/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"strings"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/tap"
Expand All @@ -19,6 +20,11 @@ type GrpcInflightMethodLimiter interface {
// otherwise gRPC-server implementation-specific error will be returned to the client (codes.PermissionDenied in [email protected]).
RPCCallStarting(ctx context.Context, methodName string, md metadata.MD) (context.Context, error)

// RPCCallProcessing is called by a server interceptor, allowing request pre-processing or blocking to be performed.
// handler should propagate the req.
RPCCallProcessing(ctx context.Context, methodName string) (func(error), error)

// RPCCallFinished is called when an RPC call is finished being handled.
RPCCallFinished(ctx context.Context)
}

Expand Down Expand Up @@ -47,6 +53,31 @@ func (g *grpcInflightLimitCheck) TapHandle(ctx context.Context, info *tap.Info)
return g.methodLimiter.RPCCallStarting(ctx, info.FullMethodName, info.Header)
}

func (g *grpcInflightLimitCheck) UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
finish, err := g.methodLimiter.RPCCallProcessing(ctx, info.FullMethod)
if err != nil {
return nil, err
}
result, err := handler(ctx, req)
if finish != nil {
finish(err)
}
return result, err

}

func (g *grpcInflightLimitCheck) StreamServerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
finish, err := g.methodLimiter.RPCCallProcessing(ss.Context(), info.FullMethod)
if err != nil {
return err
}
err = handler(srv, ss)
if finish != nil {
finish(err)
}
return err
}

func (g *grpcInflightLimitCheck) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
return ctx
}
Expand Down
15 changes: 12 additions & 3 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,13 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) {
PermitWithoutStream: cfg.GRPCServerPingWithoutStreamAllowed,
}

var grpcServerLimit *grpcInflightLimitCheck
if cfg.GrpcMethodLimiter != nil {
grpcServerLimit = newGrpcInflightLimitCheck(cfg.GrpcMethodLimiter)
grpcMiddleware = append(grpcMiddleware, grpcServerLimit.UnaryServerInterceptor)
grpcStreamMiddleware = append(grpcStreamMiddleware, grpcServerLimit.StreamServerInterceptor)
}

grpcOptions := []grpc.ServerOption{
grpc.ChainUnaryInterceptor(grpcMiddleware...),
grpc.ChainStreamInterceptor(grpcStreamMiddleware...),
Expand All @@ -431,9 +438,11 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) {
grpc.NumStreamWorkers(uint32(cfg.GRPCServerNumWorkers)),
}

if cfg.GrpcMethodLimiter != nil {
grpcServerLimit := newGrpcInflightLimitCheck(cfg.GrpcMethodLimiter)
grpcOptions = append(grpcOptions, grpc.InTapHandle(grpcServerLimit.TapHandle), grpc.StatsHandler(grpcServerLimit))
if grpcServerLimit != nil {
grpcOptions = append(grpcOptions,
grpc.StatsHandler(grpcServerLimit),
grpc.InTapHandle(grpcServerLimit.TapHandle),
)
}

if cfg.GRPCServerStatsTrackingEnabled {
Expand Down

0 comments on commit ac300a1

Please sign in to comment.