From 5959eeb144147f38dfed7edcbd3751872b396cb4 Mon Sep 17 00:00:00 2001 From: Ilya Hontarau Date: Mon, 12 Aug 2024 09:59:07 +0200 Subject: [PATCH] Refactor auto scaler logic (#13) * extract autoscaler logic * add tests for autoscaler logic * gen mocks in test pkg * rename mock --- cmd/internal/handle.go | 111 +----------------- internal/auto_scaler.go | 141 +++++++++++++++++++++++ internal/auto_scaler_test.go | 122 ++++++++++++++++++++ internal/mock_controller_test.go | 188 +++++++++++++++++++++++++++++++ 4 files changed, 452 insertions(+), 110 deletions(-) create mode 100644 internal/auto_scaler.go create mode 100644 internal/auto_scaler_test.go create mode 100644 internal/mock_controller_test.go diff --git a/cmd/internal/handle.go b/cmd/internal/handle.go index 83cdbf9..3a4084a 100644 --- a/cmd/internal/handle.go +++ b/cmd/internal/handle.go @@ -3,7 +3,6 @@ package internal import ( "context" "fmt" - "time" "github.com/caarlos0/env/v9" "golang.org/x/exp/slog" @@ -17,117 +16,9 @@ func Handle(ctx context.Context, logger *slog.Logger) error { return fmt.Errorf("could not parse environment variables: %w", err) } - logger = logger.With( - "asg_arn", cfg.AutoscalingGroupARN, - "worker_pool_id", cfg.SpaceliftWorkerPoolID, - ) - controller, err := internal.NewController(ctx, &cfg) if err != nil { return fmt.Errorf("could not create controller: %w", err) } - - workerPool, err := controller.GetWorkerPool(ctx) - if err != nil { - return fmt.Errorf("could not get worker pool: %w", err) - } - - asg, err := controller.GetAutoscalingGroup(ctx) - if err != nil { - return fmt.Errorf("could not get autoscaling group: %w", err) - } - - state, err := internal.NewState(workerPool, asg) - if err != nil { - return fmt.Errorf("could not create state: %w", err) - } - - // Let's make sure that for each of the in-service instances we have a - // corresponding worker in Spacelift, or that we have "stray" machines. - if strayInstances := state.StrayInstances(); len(strayInstances) > 0 { - // There's a question of what to do with the "stray" machines. The - // decision will be made based on the creation timestamp. - instances, err := controller.DescribeInstances(ctx, strayInstances) - if err != nil { - return fmt.Errorf("could not list EC2 instances: %w", err) - } - - for _, instance := range instances { - logger = logger.With("instance_id", *instance.InstanceId) - instanceAge := time.Since(*instance.LaunchTime) - - logger = logger.With( - "launch_timestamp", instance.LaunchTime.Unix(), - "instance_age", instanceAge, - ) - - // If the machine was only created recently (say a generous window of 10 - // minutes), it is possible that it hasn't managed to register itself with - // Spacelift yet. But if it's been around for a while we will want to kill - // it and remove it from the ASG. - if instanceAge > 10*time.Minute { - logger.Warn("instance has no corresponding worker in Spacelift, removing from the ASG") - - if err := controller.KillInstance(ctx, *instance.InstanceId); err != nil { - return fmt.Errorf("could not kill instance: %w", err) - } - - // We don't want to kill too many instances at once, so let's - // return after the first successfully killed one. - logger.Info("instance successfully removed from the ASG and terminated") - - return nil - } - } - } - - decision := state.Decide(cfg.AutoscalingMaxCreate, cfg.AutoscalingMaxKill) - - if decision.ScalingDirection == internal.ScalingDirectionNone { - logger.Info("no scaling decision to be made") - return nil - } - - if decision.ScalingDirection == internal.ScalingDirectionUp { - logger.With("instances", decision.ScalingSize).Info("scaling up the ASG") - - if err := controller.ScaleUpASG(ctx, *asg.DesiredCapacity+int32(decision.ScalingSize)); err != nil { - return fmt.Errorf("could not scale up ASG: %w", err) - } - - return nil - } - - // If we got this far, we're scaling down. - logger.With("instances", decision.ScalingSize).Info("scaling down ASG") - - idleWorkers := state.IdleWorkers() - - for i := 0; i < decision.ScalingSize; i++ { - worker := idleWorkers[i] - - _, instanceID, _ := worker.InstanceIdentity() - - logger = logger.With( - "worker_id", worker.ID, - "instance_id", instanceID, - ) - logger.Info("scaling down ASG and killing worker") - - drained, err := controller.DrainWorker(ctx, worker.ID) - if err != nil { - return fmt.Errorf("could not drain worker: %w", err) - } - - if !drained { - logger.Warn("worker was busy, stopping the scaling down process") - return nil - } - - if err := controller.KillInstance(ctx, string(instanceID)); err != nil { - return fmt.Errorf("could not kill instance: %w", err) - } - } - - return nil + return internal.NewAutoScaler(controller, logger).Scale(ctx, cfg) } diff --git a/internal/auto_scaler.go b/internal/auto_scaler.go new file mode 100644 index 0000000..08aa803 --- /dev/null +++ b/internal/auto_scaler.go @@ -0,0 +1,141 @@ +package internal + +import ( + "context" + "fmt" + "time" + + autoscalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "golang.org/x/exp/slog" +) + +//go:generate mockery --output ./ --name ControllerInterface --filename mock_controller_test.go --outpkg internal_test --structname MockController +type ControllerInterface interface { + DescribeInstances(ctx context.Context, instanceIDs []string) (instances []ec2types.Instance, err error) + GetAutoscalingGroup(ctx context.Context) (out *autoscalingtypes.AutoScalingGroup, err error) + GetWorkerPool(ctx context.Context) (out *WorkerPool, err error) + DrainWorker(ctx context.Context, workerID string) (drained bool, err error) + KillInstance(ctx context.Context, instanceID string) (err error) + ScaleUpASG(ctx context.Context, desiredCapacity int32) (err error) +} + +type AutoScaler struct { + controller ControllerInterface + logger *slog.Logger +} + +func NewAutoScaler(controller ControllerInterface, logger *slog.Logger) *AutoScaler { + return &AutoScaler{controller: controller, logger: logger} +} + +func (s AutoScaler) Scale(ctx context.Context, cfg RuntimeConfig) error { + logger := s.logger.With( + "asg_arn", cfg.AutoscalingGroupARN, + "worker_pool_id", cfg.SpaceliftWorkerPoolID, + ) + + workerPool, err := s.controller.GetWorkerPool(ctx) + if err != nil { + return fmt.Errorf("could not get worker pool: %w", err) + } + + asg, err := s.controller.GetAutoscalingGroup(ctx) + if err != nil { + return fmt.Errorf("could not get autoscaling group: %w", err) + } + + state, err := NewState(workerPool, asg) + if err != nil { + return fmt.Errorf("could not create state: %w", err) + } + + // Let's make sure that for each of the in-service instances we have a + // corresponding worker in Spacelift, or that we have "stray" machines. + if strayInstances := state.StrayInstances(); len(strayInstances) > 0 { + // There's a question of what to do with the "stray" machines. The + // decision will be made based on the creation timestamp. + instances, err := s.controller.DescribeInstances(ctx, strayInstances) + if err != nil { + return fmt.Errorf("could not list EC2 instances: %w", err) + } + + for _, instance := range instances { + logger = logger.With("instance_id", *instance.InstanceId) + instanceAge := time.Since(*instance.LaunchTime) + + logger = logger.With( + "launch_timestamp", instance.LaunchTime.Unix(), + "instance_age", instanceAge, + ) + + // If the machine was only created recently (say a generous window of 10 + // minutes), it is possible that it hasn't managed to register itself with + // Spacelift yet. But if it's been around for a while we will want to kill + // it and remove it from the ASG. + if instanceAge > 10*time.Minute { + logger.Warn("instance has no corresponding worker in Spacelift, removing from the ASG") + + if err := s.controller.KillInstance(ctx, *instance.InstanceId); err != nil { + return fmt.Errorf("could not kill instance: %w", err) + } + + // We don't want to kill too many instances at once, so let's + // return after the first successfully killed one. + logger.Info("instance successfully removed from the ASG and terminated") + + return nil + } + } + } + + decision := state.Decide(cfg.AutoscalingMaxCreate, cfg.AutoscalingMaxKill) + + if decision.ScalingDirection == ScalingDirectionNone { + logger.Info("no scaling decision to be made") + return nil + } + + if decision.ScalingDirection == ScalingDirectionUp { + logger.With("instances", decision.ScalingSize).Info("scaling up the ASG") + + if err := s.controller.ScaleUpASG(ctx, *asg.DesiredCapacity+int32(decision.ScalingSize)); err != nil { + return fmt.Errorf("could not scale up ASG: %w", err) + } + + return nil + } + + // If we got this far, we're scaling down. + logger.With("instances", decision.ScalingSize).Info("scaling down ASG") + + idleWorkers := state.IdleWorkers() + + for i := 0; i < decision.ScalingSize; i++ { + worker := idleWorkers[i] + + _, instanceID, _ := worker.InstanceIdentity() + + logger = logger.With( + "worker_id", worker.ID, + "instance_id", instanceID, + ) + logger.Info("scaling down ASG and killing worker") + + drained, err := s.controller.DrainWorker(ctx, worker.ID) + if err != nil { + return fmt.Errorf("could not drain worker: %w", err) + } + + if !drained { + logger.Warn("worker was busy, stopping the scaling down process") + return nil + } + + if err := s.controller.KillInstance(ctx, string(instanceID)); err != nil { + return fmt.Errorf("could not kill instance: %w", err) + } + } + + return nil +} diff --git a/internal/auto_scaler_test.go b/internal/auto_scaler_test.go new file mode 100644 index 0000000..fb834a6 --- /dev/null +++ b/internal/auto_scaler_test.go @@ -0,0 +1,122 @@ +package internal_test + +import ( + "bytes" + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slog" + + "github.com/spacelift-io/awsautoscalr/internal" +) + +func TestAutoScalerScalingNone(t *testing.T) { + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, nil) + + cfg := internal.RuntimeConfig{} + + ctrl := new(MockController) + defer ctrl.AssertExpectations(t) + + scaler := internal.NewAutoScaler(ctrl, slog.New(h)) + + ctrl.On("GetWorkerPool", mock.Anything).Return(&internal.WorkerPool{ + Workers: []internal.Worker{ + { + ID: "1", + Metadata: `{"asg_id": "group", "instance_id": "instance"}`, + }, + }, + }, nil) + ctrl.On("GetAutoscalingGroup", mock.Anything).Return(&types.AutoScalingGroup{ + AutoScalingGroupName: ptr("group"), + MinSize: ptr(int32(1)), + MaxSize: ptr(int32(3)), + DesiredCapacity: ptr(int32(2)), + }, nil) + err := scaler.Scale(context.Background(), cfg) + require.NoError(t, err) +} + +func TestAutoScalerScalingUp(t *testing.T) { + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, nil) + + cfg := internal.RuntimeConfig{} + + ctrl := new(MockController) + defer ctrl.AssertExpectations(t) + + scaler := internal.NewAutoScaler(ctrl, slog.New(h)) + + ctrl.On("GetWorkerPool", mock.Anything).Return(&internal.WorkerPool{ + Workers: []internal.Worker{ + { + ID: "1", + Metadata: `{"asg_id": "group", "instance_id": "instance"}`, + }, + }, + PendingRuns: 2, + }, nil) + ctrl.On("GetAutoscalingGroup", mock.Anything).Return(&types.AutoScalingGroup{ + AutoScalingGroupName: ptr("group"), + MinSize: ptr(int32(1)), + MaxSize: ptr(int32(3)), + DesiredCapacity: ptr(int32(2)), + Instances: []types.Instance{ + {}, + }, + }, nil) + ctrl.On("ScaleUpASG", mock.Anything, int32(2)).Return(nil) + err := scaler.Scale(context.Background(), cfg) + require.NoError(t, err) +} + +func TestAutoScalerScalingDown(t *testing.T) { + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, nil) + + cfg := internal.RuntimeConfig{ + AutoscalingMaxKill: 1, + } + + ctrl := new(MockController) + defer ctrl.AssertExpectations(t) + + scaler := internal.NewAutoScaler(ctrl, slog.New(h)) + + ctrl.On("GetWorkerPool", mock.Anything).Return(&internal.WorkerPool{ + Workers: []internal.Worker{ + { + ID: "1", + Metadata: `{"asg_id": "group", "instance_id": "instance"}`, + }, + { + ID: "2", + Metadata: `{"asg_id": "group", "instance_id": "instance2"}`, + }, + }, + }, nil) + ctrl.On("GetAutoscalingGroup", mock.Anything).Return(&types.AutoScalingGroup{ + AutoScalingGroupName: ptr("group"), + MinSize: ptr(int32(1)), + MaxSize: ptr(int32(3)), + DesiredCapacity: ptr(int32(2)), + Instances: []types.Instance{ + {}, + {}, + }, + }, nil) + ctrl.On("DrainWorker", mock.Anything, "1").Return(true, nil) + ctrl.On("KillInstance", mock.Anything, "instance").Return(nil) + err := scaler.Scale(context.Background(), cfg) + require.NoError(t, err) +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/internal/mock_controller_test.go b/internal/mock_controller_test.go new file mode 100644 index 0000000..f8a4eec --- /dev/null +++ b/internal/mock_controller_test.go @@ -0,0 +1,188 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package internal_test + +import ( + context "context" + + autoscalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" + + internal "github.com/spacelift-io/awsautoscalr/internal" + + mock "github.com/stretchr/testify/mock" + + types "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +// MockController is an autogenerated mock type for the ControllerInterface type +type MockController struct { + mock.Mock +} + +// DescribeInstances provides a mock function with given fields: ctx, instanceIDs +func (_m *MockController) DescribeInstances(ctx context.Context, instanceIDs []string) ([]types.Instance, error) { + ret := _m.Called(ctx, instanceIDs) + + if len(ret) == 0 { + panic("no return value specified for DescribeInstances") + } + + var r0 []types.Instance + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) ([]types.Instance, error)); ok { + return rf(ctx, instanceIDs) + } + if rf, ok := ret.Get(0).(func(context.Context, []string) []types.Instance); ok { + r0 = rf(ctx, instanceIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Instance) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, instanceIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DrainWorker provides a mock function with given fields: ctx, workerID +func (_m *MockController) DrainWorker(ctx context.Context, workerID string) (bool, error) { + ret := _m.Called(ctx, workerID) + + if len(ret) == 0 { + panic("no return value specified for DrainWorker") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return rf(ctx, workerID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, workerID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, workerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAutoscalingGroup provides a mock function with given fields: ctx +func (_m *MockController) GetAutoscalingGroup(ctx context.Context) (*autoscalingtypes.AutoScalingGroup, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAutoscalingGroup") + } + + var r0 *autoscalingtypes.AutoScalingGroup + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*autoscalingtypes.AutoScalingGroup, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *autoscalingtypes.AutoScalingGroup); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*autoscalingtypes.AutoScalingGroup) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetWorkerPool provides a mock function with given fields: ctx +func (_m *MockController) GetWorkerPool(ctx context.Context) (*internal.WorkerPool, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerPool") + } + + var r0 *internal.WorkerPool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*internal.WorkerPool, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *internal.WorkerPool); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internal.WorkerPool) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// KillInstance provides a mock function with given fields: ctx, instanceID +func (_m *MockController) KillInstance(ctx context.Context, instanceID string) error { + ret := _m.Called(ctx, instanceID) + + if len(ret) == 0 { + panic("no return value specified for KillInstance") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, instanceID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ScaleUpASG provides a mock function with given fields: ctx, desiredCapacity +func (_m *MockController) ScaleUpASG(ctx context.Context, desiredCapacity int32) error { + ret := _m.Called(ctx, desiredCapacity) + + if len(ret) == 0 { + panic("no return value specified for ScaleUpASG") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int32) error); ok { + r0 = rf(ctx, desiredCapacity) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockController creates a new instance of MockController. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockController(t interface { + mock.TestingT + Cleanup(func()) +}) *MockController { + mock := &MockController{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}