Skip to content

Commit 7adba6e

Browse files
committed
[keyvault/azsecrets] make azsecrets.Client thread-safe
azsecrets.Client uses NewKeyVaultChallengePolicy. This policy is not goroutine-safe, violating the documented requirement that policies are goroutine-safe [1]. This leads to data races which are reported by Go's race detector. Fix NewKeyVaultChallengePolicy to be goroutine-safe using a mutex. This can lead to redundant preflight requests, but at least Go's race detector no longer complains. Test plan: $ cd sdk/security/keyvault/internal/ $ go test -race [1] https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-core-concepts
1 parent 90c29cc commit 7adba6e

File tree

3 files changed

+96
-6
lines changed

3 files changed

+96
-6
lines changed

sdk/security/keyvault/azsecrets/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
### Breaking Changes
88

99
### Bugs Fixed
10+
* Fixed data race when using Client from multiple goroutines concurrently.
1011

1112
### Other Changes
1213

sdk/security/keyvault/internal/challenge_policy.go

+31-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"net/http"
1414
"net/url"
1515
"strings"
16+
"sync"
1617

1718
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1819
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
@@ -30,7 +31,12 @@ type KeyVaultChallengePolicyOptions struct {
3031
type keyVaultAuthorizer struct {
3132
// tro is the policy's authentication parameters. These are discovered from an authentication challenge
3233
// elicited ahead of the first client request.
33-
tro policy.TokenRequestOptions
34+
//
35+
// Protected by troLock.
36+
tro policy.TokenRequestOptions
37+
// Lock protecting tro in case there are multiple concurrent initial requests.
38+
troLock sync.RWMutex
39+
3440
verifyChallengeResource bool
3541
}
3642

@@ -55,7 +61,8 @@ func NewKeyVaultChallengePolicy(cred azcore.TokenCredential, opts *KeyVaultChall
5561
}
5662

5763
func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
58-
if len(k.tro.Scopes) == 0 || k.tro.TenantID == "" {
64+
tro := k.getTokenRequestOptions()
65+
if len(tro.Scopes) == 0 || tro.TenantID == "" {
5966
if body := req.Body(); body != nil {
6067
// We don't know the scope or tenant ID because we haven't seen a challenge yet. We elicit one now by sending
6168
// the request without authorization, first removing its body, if any. authorizeOnChallenge will reattach the
@@ -70,7 +77,7 @@ func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.T
7077
return nil
7178
}
7279
// else we know the auth parameters and can authorize the request as normal
73-
return authNZ(k.tro)
80+
return authNZ(tro)
7481
}
7582

7683
func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http.Response, authNZ func(policy.TokenRequestOptions) error) error {
@@ -87,7 +94,7 @@ func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http
8794
}
8895
}
8996
// authenticate with the parameters supplied by Key Vault, authorize the request, send it again
90-
return authNZ(k.tro)
97+
return authNZ(k.getTokenRequestOptions())
9198
}
9299

93100
// parses Tenant ID from auth challenge
@@ -126,7 +133,6 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
126133
}
127134
}
128135

129-
k.tro.TenantID = parseTenant(vals["authorization"])
130136
scope := ""
131137
if v, ok := vals["scope"]; ok {
132138
scope = v
@@ -149,6 +155,25 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
149155
if !strings.HasSuffix(scope, "/.default") {
150156
scope += "/.default"
151157
}
152-
k.tro.Scopes = []string{scope}
158+
k.setTokenRequestOptions(policy.TokenRequestOptions{
159+
TenantID: parseTenant(vals["authorization"]),
160+
Scopes: []string{scope},
161+
})
153162
return nil
154163
}
164+
165+
// Returns a (possibly-zero) copy of TokenRequestOptions.
166+
//
167+
// The returned value's Scopes and other fields must not be modified.
168+
func (k *keyVaultAuthorizer) getTokenRequestOptions() policy.TokenRequestOptions {
169+
k.troLock.RLock()
170+
defer k.troLock.RUnlock()
171+
return k.tro // Copy.
172+
}
173+
174+
// After calling this function, tro.Scopes and other fields must not be modified.
175+
func (k *keyVaultAuthorizer) setTokenRequestOptions(tro policy.TokenRequestOptions) {
176+
k.troLock.Lock()
177+
defer k.troLock.Unlock()
178+
k.tro = tro // Copy.
179+
}

sdk/security/keyvault/internal/challenge_policy_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ import (
1111
"context"
1212
"fmt"
1313
"net/http"
14+
"net/http/httptest"
1415
"strings"
16+
"sync"
17+
"sync/atomic"
1518
"testing"
1619
"time"
1720

@@ -355,3 +358,64 @@ func TestParseTenant(t *testing.T) {
355358
actual = parseTenant(sampleURL)
356359
require.Equal(t, expected, actual, "tenant was not properly parsed, got %s, expected %s", actual, expected)
357360
}
361+
362+
func TestChallengePolicy_ConcurrentRequests(t *testing.T) {
363+
concurrentRequestCount := 3
364+
365+
serverAuthenticateRequests := atomic.Int32{}
366+
serverAuthenticatedRequests := atomic.Int32{}
367+
var srv *httptest.Server
368+
srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
369+
authz := r.Header.Values("Authorization")
370+
if len(authz) == 0 {
371+
// Initial request without Authorization header. Send a
372+
// challenge response to the client.
373+
serverAuthenticateRequests.Add(1)
374+
resource := srv.URL
375+
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer authorization="https://login.microsoftonline.com/{tenant}", resource="%s"`, resource))
376+
w.WriteHeader(401)
377+
} else {
378+
// Authenticated request.
379+
serverAuthenticatedRequests.Add(1)
380+
if len(authz) != 1 || authz[0] != "Bearer ***" {
381+
t.Errorf(`unexpected Authorization "%s"`, authz)
382+
}
383+
// Return nothing.
384+
w.WriteHeader(200)
385+
}
386+
}))
387+
defer srv.Close()
388+
srv.StartTLS()
389+
390+
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
391+
return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
392+
})
393+
p := NewKeyVaultChallengePolicy(cred, &KeyVaultChallengePolicyOptions{
394+
// Challenge resource verification will always fail because we
395+
// use local IPs instead of domain names and subdomains in this
396+
// test.
397+
DisableChallengeResourceVerification: true,
398+
})
399+
pl := runtime.NewPipeline("", "",
400+
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
401+
&policy.ClientOptions{Transport: srv.Client()},
402+
)
403+
404+
wg := sync.WaitGroup{}
405+
for i := 0; i < concurrentRequestCount; i += 1 {
406+
go (func() {
407+
defer wg.Done()
408+
req, err := runtime.NewRequest(context.Background(), "GET", srv.URL)
409+
require.NoError(t, err)
410+
res, err := pl.Do(req)
411+
require.NoError(t, err)
412+
defer res.Body.Close()
413+
})()
414+
wg.Add(1)
415+
}
416+
wg.Wait()
417+
418+
require.GreaterOrEqual(t, int(serverAuthenticateRequests.Load()), 1, "client should have sent at least one preflight request")
419+
require.LessOrEqual(t, int(serverAuthenticateRequests.Load()), concurrentRequestCount, "client should have sent no more preflight requests than client requests")
420+
require.EqualValues(t, concurrentRequestCount, serverAuthenticatedRequests.Load(), "client preflight request count should equal server preflight request count")
421+
}

0 commit comments

Comments
 (0)