Skip to content

Commit

Permalink
introduce context.Context to azcache package
Browse files Browse the repository at this point in the history
Signed-off-by: Fan Shang Xiang <[email protected]>
  • Loading branch information
MartinForReal authored and k8s-infra-cherrypick-robot committed Oct 15, 2024
1 parent dc4c70d commit 2f09dfb
Show file tree
Hide file tree
Showing 64 changed files with 974 additions and 974 deletions.
27 changes: 14 additions & 13 deletions pkg/cache/azure_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cache

import (
"context"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -44,7 +45,7 @@ const (
)

// GetFunc defines a getter function for timedCache.
type GetFunc func(key string) (interface{}, error)
type GetFunc func(ctx context.Context, key string) (interface{}, error)

// AzureCacheEntry is the internal structure stores inside TTLStore.
type AzureCacheEntry struct {
Expand All @@ -64,8 +65,8 @@ func cacheKeyFunc(obj interface{}) (string, error) {

// Resource operations
type Resource interface {
Get(key string, crt AzureCacheReadType) (interface{}, error)
GetWithDeepCopy(key string, crt AzureCacheReadType) (interface{}, error)
Get(ctx context.Context, key string, crt AzureCacheReadType) (interface{}, error)
GetWithDeepCopy(ctx context.Context, key string, crt AzureCacheReadType) (interface{}, error)
Delete(key string) error
Set(key string, data interface{})
Update(key string, data interface{})
Expand Down Expand Up @@ -151,26 +152,26 @@ func (t *TimedCache) getInternal(key string) (*AzureCacheEntry, error) {
}

// Get returns the requested item by key.
func (t *TimedCache) Get(key string, crt AzureCacheReadType) (interface{}, error) {
return t.get(key, crt)
func (t *TimedCache) Get(ctx context.Context, key string, crt AzureCacheReadType) (interface{}, error) {
return t.get(ctx, key, crt)
}

func (c *ResourceProvider) Get(key string, _ AzureCacheReadType) (interface{}, error) {
return c.Getter(key)
func (c *ResourceProvider) Get(ctx context.Context, key string, _ AzureCacheReadType) (interface{}, error) {
return c.Getter(ctx, key)
}

// Get returns the requested item by key with deep copy.
func (t *TimedCache) GetWithDeepCopy(key string, crt AzureCacheReadType) (interface{}, error) {
data, err := t.get(key, crt)
func (t *TimedCache) GetWithDeepCopy(ctx context.Context, key string, crt AzureCacheReadType) (interface{}, error) {
data, err := t.get(ctx, key, crt)
copied := deepcopy.Copy(data)
return copied, err
}

func (c *ResourceProvider) GetWithDeepCopy(key string, _ AzureCacheReadType) (interface{}, error) {
return c.Getter(key)
func (c *ResourceProvider) GetWithDeepCopy(ctx context.Context, key string, _ AzureCacheReadType) (interface{}, error) {
return c.Getter(ctx, key)
}

func (t *TimedCache) get(key string, crt AzureCacheReadType) (interface{}, error) {
func (t *TimedCache) get(ctx context.Context, key string, crt AzureCacheReadType) (interface{}, error) {
entry, err := t.getInternal(key)
if err != nil {
return nil, err
Expand All @@ -193,7 +194,7 @@ func (t *TimedCache) get(key string, crt AzureCacheReadType) (interface{}, error
// Data is not cached yet, cache data is expired or requested force refresh
// cache it by getter. entry is locked before getting to ensure concurrent
// gets don't result in multiple ARM calls.
data, err := t.resourceProvider.Get(key, CacheReadTypeDefault /* not matter */)
data, err := t.resourceProvider.Get(ctx, key, CacheReadTypeDefault /* not matter */)
if err != nil {
return nil, err
}
Expand Down
34 changes: 17 additions & 17 deletions pkg/cache/azure_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ type fakeDataSource struct {
data sync.Map
}

func (fake *fakeDataSource) get(key string) (interface{}, error) {
func (fake *fakeDataSource) get(ctx context.Context, key string) (interface{}, error) {
if !fake.sem.TryAcquire(1) {
_ = fake.sem.Acquire(context.TODO(), 1)
_ = fake.sem.Acquire(ctx, 1)
fake.concurrent = true
}
defer fake.sem.Release(1)
Expand Down Expand Up @@ -116,21 +116,21 @@ func TestCacheGet(t *testing.T) {
for _, c := range cases {
dataSource, cache := newFakeCache(t)
dataSource.set(c.data)
val, err := cache.GetWithDeepCopy(c.key, CacheReadTypeDefault)
val, err := cache.GetWithDeepCopy(context.TODO(), c.key, CacheReadTypeDefault)
assert.NoError(t, err, c.name)
assert.Equal(t, c.expected, val, c.name)
}
}

func TestCacheGetError(t *testing.T) {
getError := fmt.Errorf("getError")
getter := func(_ string) (interface{}, error) {
getter := func(_ context.Context, _ string) (interface{}, error) {
return nil, getError
}
cache, err := NewTimedCache(fakeCacheTTL, getter, false)
assert.NoError(t, err)

val, err := cache.GetWithDeepCopy("key", CacheReadTypeDefault)
val, err := cache.GetWithDeepCopy(context.TODO(), "key", CacheReadTypeDefault)
assert.Error(t, err)
assert.Equal(t, getError, err)
assert.Nil(t, val)
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestCacheGetWithDeepCopy(t *testing.T) {
dataSource, cache := newFakeCache(t)
dataSource.set(c.data)
cache.Set(c.key, valFake)
val, err := cache.GetWithDeepCopy(c.key, CacheReadTypeDefault)
val, err := cache.GetWithDeepCopy(context.TODO(), c.key, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, c.expected, val.(*fakeDataObj).Data)

Expand All @@ -178,13 +178,13 @@ func TestCacheDelete(t *testing.T) {
dataSource, cache := newFakeCache(t)
dataSource.set(data)

v, err := cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err := cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, val, v, "cache should get correct data")

dataSource.set(nil)
_ = cache.Delete(testKey)
v, err = cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err = cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
assert.Equal(t, nil, v, "cache should get nil after data is removed")
Expand All @@ -198,13 +198,13 @@ func TestCacheExpired(t *testing.T) {
dataSource, cache := newFakeCache(t)
dataSource.set(data)

v, err := cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err := cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
assert.Equal(t, val, v, "cache should get correct data")

time.Sleep(fakeCacheTTL)
v, err = cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err = cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, 2, dataSource.called)
assert.Equal(t, val, v, "cache should get correct data even after expired")
Expand All @@ -218,13 +218,13 @@ func TestCacheAllowUnsafeRead(t *testing.T) {
dataSource, cache := newFakeCache(t)
dataSource.set(data)

v, err := cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err := cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
assert.Equal(t, val, v, "cache should get correct data")

time.Sleep(fakeCacheTTL)
v, err = cache.GetWithDeepCopy(testKey, CacheReadTypeUnsafe)
v, err = cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeUnsafe)
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
assert.Equal(t, val, v, "cache should return expired as allow unsafe read is allowed")
Expand All @@ -244,10 +244,10 @@ func TestCacheNoConcurrentGet(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
_, _ = cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
}()
}
v, err := cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err := cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
wg.Wait()
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
Expand All @@ -268,7 +268,7 @@ func TestCacheNoConcurrentGetWithUpdate(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.Get(testKey, CacheReadTypeForceRefresh)
_, _ = cache.Get(context.TODO(), testKey, CacheReadTypeForceRefresh)
}()
}
goUpdate := func() {
Expand Down Expand Up @@ -301,12 +301,12 @@ func TestCacheForceRefresh(t *testing.T) {
dataSource, cache := newFakeCache(t)
dataSource.set(data)

v, err := cache.GetWithDeepCopy(testKey, CacheReadTypeDefault)
v, err := cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeDefault)
assert.NoError(t, err)
assert.Equal(t, 1, dataSource.called)
assert.Equal(t, val, v, "cache should get correct data")

v, err = cache.GetWithDeepCopy(testKey, CacheReadTypeForceRefresh)
v, err = cache.GetWithDeepCopy(context.TODO(), testKey, CacheReadTypeForceRefresh)
assert.NoError(t, err)
assert.Equal(t, 2, dataSource.called)
assert.Equal(t, val, v, "should refetch unexpired data as forced refresh")
Expand Down
4 changes: 2 additions & 2 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ func (np *IMDSNodeProvider) GetZone(ctx context.Context, _ types.NodeName) (clou
}

// GetPlatformSubFaultDomain returns the PlatformSubFaultDomain from IMDS if set.
func (np *IMDSNodeProvider) GetPlatformSubFaultDomain() (string, error) {
return np.azure.GetPlatformSubFaultDomain()
func (np *IMDSNodeProvider) GetPlatformSubFaultDomain(ctx context.Context) (string, error) {
return np.azure.GetPlatformSubFaultDomain(ctx)
}
2 changes: 1 addition & 1 deletion pkg/node/nodearm.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ func (np *ARMNodeProvider) GetZone(ctx context.Context, name types.NodeName) (cl
}

// GetPlatformSubFaultDomain returns the PlatformSubFaultDomain from IMDS if set.
func (np *ARMNodeProvider) GetPlatformSubFaultDomain() (string, error) {
func (np *ARMNodeProvider) GetPlatformSubFaultDomain(_ context.Context) (string, error) {
return "", nil
}
8 changes: 4 additions & 4 deletions pkg/nodeipam/ipam/cloud_cidr_allocator.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func NewCloudCIDRAllocator(
klog.Warningf("NewCloudCIDRAllocator: failed when trying to read the node mask size on node %s: no provider ID", node.Name)
continue
}
err := ca.updateNodeSubnetMaskSizes(node.Name, node.Spec.ProviderID)
err := ca.updateNodeSubnetMaskSizes(context.Background(), node.Name, node.Spec.ProviderID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -223,15 +223,15 @@ func (ca *cloudCIDRAllocator) updateMaxSubnetMaskSizes() {
}

// updateNodeSubnetMaskSizes gets the node's VMSS/VMAS, reads the mask size tag on it and updates them into the map
func (ca *cloudCIDRAllocator) updateNodeSubnetMaskSizes(nodeName, providerID string) error {
func (ca *cloudCIDRAllocator) updateNodeSubnetMaskSizes(ctx context.Context, nodeName, providerID string) error {
ca.lock.Lock()
defer ca.lock.Unlock()

if providerID == "" {
klog.Warningf("updateNodeSubnetMaskSizes(%s): empty providerID", providerID)
}

ipv4Mask, ipv6Mask, err := ca.cloud.VMSet.GetNodeCIDRMasksByProviderID(providerID)
ipv4Mask, ipv6Mask, err := ca.cloud.VMSet.GetNodeCIDRMasksByProviderID(ctx, providerID)
if err != nil {
klog.Warningf("updateNodeSubnetMaskSizes(%s): cannot get node subnet mask size by providerID: %v", providerID, err)
}
Expand Down Expand Up @@ -358,7 +358,7 @@ func (ca *cloudCIDRAllocator) AllocateOrOccupyCIDR(node *v1.Node) error {
return nil
}

err := ca.updateNodeSubnetMaskSizes(node.Name, node.Spec.ProviderID)
err := ca.updateNodeSubnetMaskSizes(context.Background(), node.Name, node.Spec.ProviderID)
if err != nil {
klog.Errorf("AllocateOrOccupyCIDR(%s): failed to update node subnet mask sizes: %v", node.Name, err)
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/nodeipam/ipam/cloud_cidr_allocator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func TestUpdateNodeSubnetMaskSizes(t *testing.T) {
nodeNameSubnetMaskSizesMap: make(map[string][]int),
}

err = ca.updateNodeSubnetMaskSizes("vmss-0", tc.providerID)
err = ca.updateNodeSubnetMaskSizes(context.Background(), "vmss-0", tc.providerID)
assert.Equal(t, tc.expectedErr, err)
assert.Equal(t, tc.expectedNodeNameSubnetMaskSizesMap, ca.nodeNameSubnetMaskSizesMap)
})
Expand Down
Loading

0 comments on commit 2f09dfb

Please sign in to comment.