From 440072f917b4d767e2d4b7ad0a9c73b358d2818b Mon Sep 17 00:00:00 2001 From: warber <72415058+warber@users.noreply.github.com> Date: Tue, 21 Jan 2025 19:07:07 +0100 Subject: [PATCH] fix: internal provider comparison causing race conditions in tests (#312) fix: internal provider comparison causing race conditions in tests Signed-off-by: Bernd Warmuth Co-authored-by: Todd Baert --- openfeature/event_executor.go | 41 ++++----------------- openfeature/reference.go | 29 +++++++++++++++ openfeature/reference_test.go | 69 +++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 33 deletions(-) create mode 100644 openfeature/reference.go create mode 100644 openfeature/reference_test.go diff --git a/openfeature/event_executor.go b/openfeature/event_executor.go index 9eea9220..ff381e44 100644 --- a/openfeature/event_executor.go +++ b/openfeature/event_executor.go @@ -2,7 +2,6 @@ package openfeature import ( "fmt" - "reflect" "sync" "time" @@ -68,13 +67,6 @@ type eventPayload struct { handler FeatureProvider } -// providerReference is a helper struct to store FeatureProvider with EventHandler capability along with their -// shutdown semaphore -type providerReference struct { - featureProvider FeatureProvider - shutdownSemaphore chan interface{} -} - // AddHandler adds an API(global) level handler func (e *eventExecutor) AddHandler(t EventType, c EventCallback) { e.mu.Lock() @@ -217,14 +209,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error e.mu.Lock() defer e.mu.Unlock() - // register shutdown semaphore for new default provider - sem := make(chan interface{}) - - newProvider := providerReference{ - featureProvider: provider, - shutdownSemaphore: sem, - } - + newProvider := newProviderRef(provider) oldProvider := e.defaultProviderReference e.defaultProviderReference = newProvider @@ -235,14 +220,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error func (e *eventExecutor) registerNamedEventingProvider(associatedClient string, provider FeatureProvider) error { e.mu.Lock() defer e.mu.Unlock() - - // register shutdown semaphore for new named provider - sem := make(chan interface{}) - - newProvider := providerReference{ - featureProvider: provider, - shutdownSemaphore: sem, - } + newProvider := newProviderRef(provider) oldProvider := e.namedProviderReference[associatedClient] e.namedProviderReference[associatedClient] = newProvider @@ -288,7 +266,7 @@ func (e *eventExecutor) startListeningAndShutdownOld(newProvider providerReferen // drop from active references for i, r := range e.activeSubscriptions { - if reflect.DeepEqual(oldReference.featureProvider, r.featureProvider) { + if oldReference.equals(r) { e.activeSubscriptions = append(e.activeSubscriptions[:i], e.activeSubscriptions[i+1:]...) } } @@ -332,8 +310,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) { // then run client handlers for domain, reference := range e.namedProviderReference { - if !reflect.DeepEqual(reference.featureProvider, handler) { - // unassociated client, continue to next + if !reference.equals(newProviderRef(handler)) { continue } @@ -343,7 +320,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) { } } - if !reflect.DeepEqual(e.defaultProviderReference.featureProvider, handler) { + if !e.defaultProviderReference.equals(newProviderRef(handler)) { return } @@ -386,25 +363,23 @@ func (e *eventExecutor) executeHandler(f func(details EventDetails), event Event // isRunning is a helper till we bump to the latest go version with slices.contains support func isRunning(provider providerReference, activeProviders []providerReference) bool { for _, activeProvider := range activeProviders { - if reflect.DeepEqual(activeProvider.featureProvider, provider.featureProvider) { + if activeProvider.equals(provider) { return true } } - return false } // isRunning is a helper to check if given provider is already in use func isBound(provider providerReference, defaultProvider providerReference, namedProviders []providerReference) bool { - if reflect.DeepEqual(provider.featureProvider, defaultProvider.featureProvider) { + if provider.equals(defaultProvider) { return true } for _, namedProvider := range namedProviders { - if reflect.DeepEqual(provider.featureProvider, namedProvider.featureProvider) { + if provider.equals(namedProvider) { return true } } - return false } diff --git a/openfeature/reference.go b/openfeature/reference.go new file mode 100644 index 00000000..7c6dff3a --- /dev/null +++ b/openfeature/reference.go @@ -0,0 +1,29 @@ +package openfeature + +import ( + "reflect" +) + +// newProviderRef creates a new providerReference instance that wraps around a FeatureProvider implementation +func newProviderRef(provider FeatureProvider) providerReference { + return providerReference{ + featureProvider: provider, + kind: reflect.TypeOf(provider).Kind(), + shutdownSemaphore: make(chan interface{}), + } +} + +// providerReference is a helper struct to store FeatureProvider along with their +// shutdown semaphore +type providerReference struct { + featureProvider FeatureProvider + kind reflect.Kind + shutdownSemaphore chan interface{} +} + +func (pr providerReference) equals(other providerReference) bool { + if pr.kind == reflect.Ptr && other.kind == reflect.Ptr { + return pr.featureProvider == other.featureProvider + } + return reflect.DeepEqual(pr.featureProvider, other.featureProvider) +} diff --git a/openfeature/reference_test.go b/openfeature/reference_test.go new file mode 100644 index 00000000..0f5462fc --- /dev/null +++ b/openfeature/reference_test.go @@ -0,0 +1,69 @@ +package openfeature + +import ( + "testing" +) + +func TestProviderReferenceEquals(t *testing.T) { + + type myProvider struct { + NoopProvider + field string + } + + p1 := myProvider{} + p2 := myProvider{} + + tests := []struct { + name string + pr1 providerReference + pr2 providerReference + expected bool + }{ + + { + name: "both pointers, different instances", + pr1: newProviderRef(&p1), + pr2: newProviderRef(&p2), + expected: false, + }, + { + name: "both pointers, same instance", + pr1: newProviderRef(&p1), + pr2: newProviderRef(&p1), + expected: true, + }, + { + name: "different pointers, different instance", + pr1: newProviderRef(p1), + pr2: newProviderRef(&p1), + expected: false, + }, + { + name: "no pointers, same instance", + pr1: newProviderRef(p1), + pr2: newProviderRef(p1), + expected: true, + }, + { + name: "no pointers, different equal instances", + pr1: newProviderRef(myProvider{field: "A"}), + pr2: newProviderRef(myProvider{field: "A"}), + expected: true, + }, + { + name: "no pointers, different not equal instances", + pr1: newProviderRef(myProvider{field: "A"}), + pr2: newProviderRef(myProvider{field: "B"}), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.pr1.equals(tt.pr2); got != tt.expected { + t.Errorf("providerReference.equals() = %v, want %v", got, tt.expected) + } + }) + } +}