diff --git a/gbus/worker.go b/gbus/worker.go index 645ec93..9e4edc4 100644 --- a/gbus/worker.go +++ b/gbus/worker.go @@ -102,9 +102,6 @@ func (worker *worker) createMessagesChannel(q amqp.Queue, consumerTag string) (< func (worker *worker) consumeMessages() { for msg := range worker.messages { - if msg.Body == nil || len(msg.Body) == 0 { - continue - } worker.processMessage(msg, false) } } @@ -112,9 +109,6 @@ func (worker *worker) consumeMessages() { func (worker *worker) consumeRPC() { for msg := range worker.rpcMessages { - if msg.Body == nil || len(msg.Body) == 0 { - continue - } worker.processMessage(msg, true) } } @@ -311,6 +305,17 @@ func (worker *worker) processMessage(delivery amqp.Delivery, isRPCreply bool) { _ = worker.ack(delivery) return } + + if delivery.Body == nil || len(delivery.Body) == 0 { + worker.log(). + WithFields( + logrus.Fields{"message-name": msgName}). + Warn("body is missing for message. Cannot invoke handlers.") + worker.span.LogFields(slog.String("grabbit", "no body found")) + // if there are handlers registered for this type of message, it's a bug and the message must be rejected. + _ = worker.reject(false, delivery) + return + } /* extract the bus message only after we are sure there are registered handlers since it includes deserializing the amqp payload which we want to avoid if no handlers are found diff --git a/tests/bus_test.go b/tests/bus_test.go index b1cb55e..e3da0ea 100644 --- a/tests/bus_test.go +++ b/tests/bus_test.go @@ -50,16 +50,14 @@ func TestSendCommand(t *testing.T) { if err != nil { t.Errorf("could not start bus for test error: %s", err.Error()) } + defer assertBusShutdown(b, t) err = b.Send(noopTraceContext(), testSvc1, gbus.NewBusMessage(cmd)) if err != nil { t.Errorf("could not send message error: %s", err.Error()) return } - - <-proceed - b.Shutdown() - + proceedOrTimeout(2, proceed, nil, t) } func TestReply(t *testing.T) { @@ -84,7 +82,7 @@ func TestReply(t *testing.T) { replyHandler := func(invocation gbus.Invocation, message *gbus.BusMessage) error { _, ok := message.Payload.(*Reply1) if !ok { - t.Errorf("message handler for reply message invoced with wrong message type\r\n%v", message) + t.Errorf("message handler for reply message invoked with wrong message type\r\n%v", message) } proceed <- true @@ -95,13 +93,13 @@ func TestReply(t *testing.T) { svc1.HandleMessage(reply, replyHandler) svc1.Start() - defer svc1.Shutdown() + defer assertBusShutdown(svc1, t) svc2.Start() - defer svc2.Shutdown() + defer assertBusShutdown(svc2, t) svc1.Send(noopTraceContext(), testSvc2, cmdBusMsg) - <-proceed + proceedOrTimeout(2, proceed, nil, t) } func TestPubSub(t *testing.T) { @@ -116,13 +114,12 @@ func TestPubSub(t *testing.T) { b.HandleEvent("test_exchange", "test_topic", event, eventHandler) b.Start() - defer b.Shutdown() + defer assertBusShutdown(b, t) err := b.Publish(noopTraceContext(), "test_exchange", "test_topic", gbus.NewBusMessage(event)) if err != nil { t.Fatal(err) } - <-proceed - + proceedOrTimeout(2, proceed, nil, t) } func TestSubscribingOnTopic(t *testing.T) { @@ -137,12 +134,12 @@ func TestSubscribingOnTopic(t *testing.T) { b.HandleEvent("test_exchange", "a.*.c", nil, eventHandler) b.Start() - defer b.Shutdown() + defer assertBusShutdown(b, t) err := b.Publish(noopTraceContext(), "test_exchange", "a.b.c", gbus.NewBusMessage(event)) if err != nil { t.Fatal(err) } - <-proceed + proceedOrTimeout(2, proceed, nil, t) } var ( @@ -167,7 +164,7 @@ func TestHandlerRetry(t *testing.T) { bus.HandleMessage(r1, handleRetry) bus.Start() - defer bus.Shutdown() + defer assertBusShutdown(bus, t) bus.Send(noopTraceContext(), testSvc1, cmd) <-handlerRetryProceed @@ -214,10 +211,10 @@ func TestRPC(t *testing.T) { svc1 := createNamedBusForTest(testSvc1) svc1.HandleMessage(c1, handler) svc1.Start() - defer svc1.Shutdown() + defer assertBusShutdown(svc1, t) svc2 := createNamedBusForTest(testSvc2) svc2.Start() - defer svc2.Shutdown() + defer assertBusShutdown(svc2, t) t.Log("Sending RPC") reply, _ = svc2.RPC(noopTraceContext(), testSvc1, cmd, reply, 5*time.Second) t.Log("Tested RPC") @@ -248,14 +245,15 @@ func TestDeadlettering(t *testing.T) { service1.HandleMessage(Command1{}, faultyHandler) deadletterSvc.Start() - defer deadletterSvc.Shutdown() + defer assertBusShutdown(deadletterSvc, t) service1.Start() - defer service1.Shutdown() + defer assertBusShutdown(service1, t) service1.Send(context.Background(), testSvc1, poison) service1.Send(context.Background(), testSvc1, gbus.NewBusMessage(Command1{})) - <-proceed + proceedOrTimeout(2, proceed, nil, t) + count, _ := metrics.GetRejectedMessagesValue() if count != 1 { t.Error("Should have one rejected message") @@ -290,13 +288,12 @@ func TestRawMessageHandling(t *testing.T) { svc1 := createNamedBusForTest(testSvc1) svc1.SetGlobalRawMessageHandler(handler) _ = svc1.Start() + defer assertBusShutdown(svc1, t) cmd1 := gbus.NewBusMessage(Command1{}) _ = svc1.Send(context.Background(), testSvc1, cmd1) - <-proceed - _ = svc1.Shutdown() - + proceedOrTimeout(2, proceed, nil, t) } func TestReturnDeadToQueue(t *testing.T) { @@ -330,18 +327,12 @@ func TestReturnDeadToQueue(t *testing.T) { service1.HandleMessage(Command1{}, faultyHandler) deadletterSvc.Start() - defer deadletterSvc.Shutdown() + defer assertBusShutdown(deadletterSvc, t) service1.Start() - defer service1.Shutdown() + defer assertBusShutdown(service1, t) service1.Send(context.Background(), testSvc1, poison) - - select { - case <-proceed: - fmt.Println("success") - case <-time.After(2 * time.Second): - t.Fatal("timeout, failed to resend dead message to queue") - } + proceedOrTimeout(2, proceed, nil, t) } func TestDeadLetterHandlerPanic(t *testing.T) { @@ -382,25 +373,20 @@ func TestDeadLetterHandlerPanic(t *testing.T) { } deadletterSvc.Start() - defer deadletterSvc.Shutdown() + defer assertBusShutdown(deadletterSvc, t) service1.Start() - defer service1.Shutdown() + defer assertBusShutdown(service1, t) service1.Send(context.Background(), testSvc1, poison) - select { - case <-proceed: + proceedOrTimeout(2, proceed, func() { count, _ := metrics.GetRejectedMessagesValue() //we expect only 1 rejcted meessage from the counter since rejected messages that get //requeued are not reported to the metric so the counter won't be increment when the message //in the dlq gets rejected as it is rejected with the requeue option set to true if count != 1 { - t.Errorf("Should have 1 rejected messages but was %v", count) } - case <-time.After(2 * time.Second): - t.Fatal("timeout, dlq failed to reject message after handler panicked") - } - + }, t) } func TestRegistrationAfterBusStarts(t *testing.T) { @@ -413,14 +399,14 @@ func TestRegistrationAfterBusStarts(t *testing.T) { return nil } b.Start() - defer b.Shutdown() + defer assertBusShutdown(b, t) b.HandleEvent("test_exchange", "test_topic", event, eventHandler) err := b.Publish(noopTraceContext(), "test_exchange", "test_topic", gbus.NewBusMessage(event)) if err != nil { t.Fatal(err) } - <-proceed + proceedOrTimeout(2, proceed, nil, t) } @@ -449,24 +435,20 @@ func TestOpenTracingReporting(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - err := b.Shutdown() - if err != nil { - t.Fatal(err) - } - }() + defer assertBusShutdown(b, t) err = b.Publish(ctx, "test_exchange", "test_topic", gbus.NewBusMessage(event)) if err != nil { t.Fatal(err) } - <-proceed - time.Sleep(2 * time.Second) - span.Finish() - spans := mockTracer.FinishedSpans() - if len(spans) < 2 { - t.Fatal("didn't send any traces in the code") - } + proceedOrTimeout(2, proceed, func() { + time.Sleep(2 * time.Second) + span.Finish() + spans := mockTracer.FinishedSpans() + if len(spans) < 2 { + t.Fatal("didn't send any traces in the code") + } + }, t) } func TestSendingPanic(t *testing.T) { @@ -476,12 +458,7 @@ func TestSendingPanic(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - err := b.Shutdown() - if err != nil { - t.Fatal(err) - } - }() + defer assertBusShutdown(b, t) defer func() { if p := recover(); p != nil { t.Fatal("expected not to have to recover this should be handled in grabbit", p) @@ -493,13 +470,145 @@ func TestSendingPanic(t *testing.T) { } } +func TestEmptyBody(t *testing.T) { + b := createNamedBusForTest(testSvc1) + proceed := make(chan bool) + b.SetGlobalRawMessageHandler(func(tx *sql.Tx, delivery *amqp.Delivery) error { + proceed <- true + return nil + }) + + err := b.Start() + if err != nil { + t.Errorf("could not start bus for test error: %s", err.Error()) + } + defer assertBusShutdown(b, t) + conn, err := amqp.Dial(connStr) + if err != nil { + t.Error("couldnt connect to rabbitmq") + } + defer conn.Close() + + ch, err := conn.Channel() + if err != nil { + t.Error("couldnt open rabbitmq channel for publishing") + } + defer ch.Close() + + cmd := amqp.Publishing{} + err = ch.Publish("", testSvc1, true, false, cmd) + if err != nil { + t.Error("couldnt send message on rabbitmq channel") + } + proceedOrTimeout(2, proceed, nil, t) +} + +func TestEmptyMessageInvokesDeadHanlder(t *testing.T) { + /* + test call for dead letter handler when a message with nil or len 0 body is consumed. the handler + should handle the message successfully. + */ + + b := createBusWithConfig(testSvc1, "grabbit-dead", true, true, + gbus.BusConfiguration{MaxRetryCount: 0, BaseRetryDuration: 0}) + + proceed := make(chan bool) + b.HandleDeadletter(func(tx *sql.Tx, delivery *amqp.Delivery) error { + proceed <- true + return nil + }) + + err := b.Start() + if err != nil { + t.Errorf("could not start bus for test error: %s", err.Error()) + } + defer assertBusShutdown(b, t) + + conn, err := amqp.Dial(connStr) + if err != nil { + t.Error("couldnt connect to rabbitmq") + } + defer conn.Close() + ch, err := conn.Channel() + if err != nil { + t.Error("couldnt open rabbitmq channel for publishing") + } + defer ch.Close() + + headersMap := make(map[string]interface{}) + headersMap["x-death"] = make([]interface{}, 0) + cmd := amqp.Publishing{Headers: headersMap} + err = ch.Publish("", testSvc1, true, false, cmd) + if err != nil { + t.Error("couldnt send message on rabbitmq channel") + } + proceedOrTimeout(2, proceed, nil, t) +} + +func TestFailHandlerInvokeOfMessageWithEmptyBody(t *testing.T) { + /* + The global and dead letter handlers can consume message with 0 or nil body but + "normal" handlers cannot. + If a "normal" handler is registered for this type of message, the bus must reject this message. + */ + metrics.ResetRejectedMessagesCounter() + b := createBusWithConfig(testSvc1, "grabbit-dead1", true, true, + gbus.BusConfiguration{MaxRetryCount: 0, BaseRetryDuration: 0}) + + proceed := make(chan bool) + b.HandleDeadletter(func(tx *sql.Tx, delivery *amqp.Delivery) error { + proceed <- true + return nil + }) + err := b.HandleMessage(Command1{}, func(invocation gbus.Invocation, message *gbus.BusMessage) error { + t.Error("handler invoked for non-grabbit message") + return nil + }) + if err != nil { + t.Errorf("could not register handler for bus %s", err.Error()) + } + + err = b.Start() + if err != nil { + t.Errorf("could not start bus for test error: %s", err.Error()) + } + defer assertBusShutdown(b, t) + + conn, err := amqp.Dial(connStr) + if err != nil { + t.Error("couldnt connect to rabbitmq") + } + defer conn.Close() + + ch, err := conn.Channel() + if err != nil { + t.Error("couldnt open rabbitmq channel for publishing") + } + defer ch.Close() + + headersMap := make(map[string]interface{}) + headersMap["x-msg-name"] = Command1{}.SchemaName() + cmd := amqp.Publishing{Headers: headersMap} + err = ch.Publish("", testSvc1, true, false, cmd) + if err != nil { + t.Error("couldnt send message on rabbitmq channel") + } + + proceedOrTimeout(2, proceed, func() { + count, _ := metrics.GetRejectedMessagesValue() + if count != 1 { + t.Error("Should have one rejected message") + } + }, t) +} + func TestHealthCheck(t *testing.T) { svc1 := createNamedBusForTest(testSvc1) err := svc1.Start() if err != nil { t.Error(err.Error()) } - defer svc1.Shutdown() + defer assertBusShutdown(svc1, t) health := svc1.GetHealth() fmt.Printf("%v", health) @@ -514,7 +623,7 @@ func TestSanitizingSvcName(t *testing.T) { if err != nil { t.Error(err.Error()) } - defer svc4.Shutdown() + defer assertBusShutdown(svc4, t) fmt.Println("succeeded sanitizing service name") } @@ -551,3 +660,21 @@ type panicPolicy struct { func (p panicPolicy) Apply(publishing *amqp.Publishing) { panic("vlad") } + +func assertBusShutdown(bus gbus.Bus, t *testing.T) { + err := bus.Shutdown() + if err != nil { + t.Fatal(err) + } +} + +func proceedOrTimeout(timeout time.Duration, p chan bool, onProceed func(), t *testing.T) { + select { + case <-p: + if onProceed != nil { + onProceed() + } + case <-time.After(timeout * time.Second): + t.Fatal("timeout") + } +}