From ac300a1de455196667f3dbd78236ee2234d8c0d0 Mon Sep 17 00:00:00 2001 From: Jonathan Halterman Date: Thu, 30 Jan 2025 18:16:02 -0800 Subject: [PATCH] Update GrpcInflightMethodLimiter to support interceptors --- server/limits.go | 31 +++++++++++++++++++++++++++++++ server/server.go | 15 ++++++++++++--- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/server/limits.go b/server/limits.go index 4a8651e32..e52898681 100644 --- a/server/limits.go +++ b/server/limits.go @@ -4,6 +4,7 @@ import ( "context" "strings" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/tap" @@ -19,6 +20,11 @@ type GrpcInflightMethodLimiter interface { // otherwise gRPC-server implementation-specific error will be returned to the client (codes.PermissionDenied in grpc@v1.55.0). RPCCallStarting(ctx context.Context, methodName string, md metadata.MD) (context.Context, error) + // RPCCallProcessing is called by a server interceptor, allowing request pre-processing or blocking to be performed. + // handler should propagate the req. + RPCCallProcessing(ctx context.Context, methodName string) (func(error), error) + + // RPCCallFinished is called when an RPC call is finished being handled. RPCCallFinished(ctx context.Context) } @@ -47,6 +53,31 @@ func (g *grpcInflightLimitCheck) TapHandle(ctx context.Context, info *tap.Info) return g.methodLimiter.RPCCallStarting(ctx, info.FullMethodName, info.Header) } +func (g *grpcInflightLimitCheck) UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + finish, err := g.methodLimiter.RPCCallProcessing(ctx, info.FullMethod) + if err != nil { + return nil, err + } + result, err := handler(ctx, req) + if finish != nil { + finish(err) + } + return result, err + +} + +func (g *grpcInflightLimitCheck) StreamServerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + finish, err := g.methodLimiter.RPCCallProcessing(ss.Context(), info.FullMethod) + if err != nil { + return err + } + err = handler(srv, ss) + if finish != nil { + finish(err) + } + return err +} + func (g *grpcInflightLimitCheck) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } diff --git a/server/server.go b/server/server.go index f6c7f997f..de16a4638 100644 --- a/server/server.go +++ b/server/server.go @@ -420,6 +420,13 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { PermitWithoutStream: cfg.GRPCServerPingWithoutStreamAllowed, } + var grpcServerLimit *grpcInflightLimitCheck + if cfg.GrpcMethodLimiter != nil { + grpcServerLimit = newGrpcInflightLimitCheck(cfg.GrpcMethodLimiter) + grpcMiddleware = append(grpcMiddleware, grpcServerLimit.UnaryServerInterceptor) + grpcStreamMiddleware = append(grpcStreamMiddleware, grpcServerLimit.StreamServerInterceptor) + } + grpcOptions := []grpc.ServerOption{ grpc.ChainUnaryInterceptor(grpcMiddleware...), grpc.ChainStreamInterceptor(grpcStreamMiddleware...), @@ -431,9 +438,11 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { grpc.NumStreamWorkers(uint32(cfg.GRPCServerNumWorkers)), } - if cfg.GrpcMethodLimiter != nil { - grpcServerLimit := newGrpcInflightLimitCheck(cfg.GrpcMethodLimiter) - grpcOptions = append(grpcOptions, grpc.InTapHandle(grpcServerLimit.TapHandle), grpc.StatsHandler(grpcServerLimit)) + if grpcServerLimit != nil { + grpcOptions = append(grpcOptions, + grpc.StatsHandler(grpcServerLimit), + grpc.InTapHandle(grpcServerLimit.TapHandle), + ) } if cfg.GRPCServerStatsTrackingEnabled {