diff --git a/pkg/aliauth/aliauth.go b/pkg/aliauth/aliauth.go new file mode 100644 index 000000000..62474cb86 --- /dev/null +++ b/pkg/aliauth/aliauth.go @@ -0,0 +1,133 @@ +package aliauth + +import ( + "fmt" + "time" + + openapi "github.com/alibabacloud-go/darabonba-openapi/client" + "github.com/alibabacloud-go/sts-20150401/client" + "github.com/aliyun/aliyun-odps-go-sdk/odps/account" + + openapiV2 "github.com/alibabacloud-go/darabonba-openapi/v2/client" +) + +var assumeRoleDefaultDuration = time.Hour +var durationSeconds = int64(assumeRoleDefaultDuration.Seconds()) + +type aliAuthAccount struct { + Account account.Account + ExpiryTime *time.Time // Only set for STS accounts +} + +type AliAuthConfig struct { + account *aliAuthAccount + regionID string +} + +func NewConfig(ramUserAccessKeyID, ramUserAccessKeySecret, regionID, ramRole, roleSessionName string) (*AliAuthConfig, error) { + if ramUserAccessKeyID == "" || ramUserAccessKeySecret == "" || regionID == "" { + return nil, fmt.Errorf("access key ID, secret, and region ID are required") + } + + if ramRole != "" && roleSessionName == "" { + return nil, fmt.Errorf("role session name is required when assuming a role") + } + + var authAccount *aliAuthAccount + if ramRole == "" { + authAccount = &aliAuthAccount{ + Account: account.NewAliyunAccount(ramUserAccessKeyID, ramUserAccessKeySecret), + } + } else { + stsAcc, expiry, err := getSTSAccount(ramRole, roleSessionName, ramUserAccessKeyID, ramUserAccessKeySecret, regionID) + if err != nil { + return nil, err + } + authAccount = &aliAuthAccount{ + Account: stsAcc, + ExpiryTime: &expiry, // Ensure expiry time is always set + } + } + + return &AliAuthConfig{account: authAccount, regionID: regionID}, nil +} + +func (a *AliAuthConfig) IsConfigValid() bool { + switch a.account.Account.(type) { + case *account.AliyunAccount: + return true + case *account.StsAccount: + if a.account.ExpiryTime == nil { + return false // Safety check to prevent nil dereference + } + return time.Now().Before(*a.account.ExpiryTime) + default: + return false + } +} + +func (a *AliAuthConfig) GetAccount() account.Account { + return a.account.Account +} + +func (a *AliAuthConfig) GetCredentials() (*openapiV2.Config, error) { + var accessKeyId, accessKeySecret, securityToken string + + switch acc := a.account.Account.(type) { + case *account.AliyunAccount: + accessKeyId = acc.AccessId() + accessKeySecret = acc.AccessKey() + case *account.StsAccount: + cred, err := acc.Credential() + if err != nil { + return &openapiV2.Config{}, fmt.Errorf("failed to get STS credentials: %w", err) + } + + if cred.AccessKeyId == nil || cred.AccessKeySecret == nil || cred.SecurityToken == nil { + return nil, fmt.Errorf("STS credentials contain nil values") + } + + accessKeyId = *cred.AccessKeyId + accessKeySecret = *cred.AccessKeySecret + securityToken = *cred.SecurityToken + default: + return &openapiV2.Config{}, fmt.Errorf("unknown account type") + } + + return &openapiV2.Config{ + AccessKeyId: &accessKeyId, + AccessKeySecret: &accessKeySecret, + SecurityToken: &securityToken, + RegionId: &a.regionID, + }, nil +} + +// getSTSAccount obtains an STS account by assuming a RAM role +func getSTSAccount(ramRole, roleSessionName, accessKeyID, accessKeySecret, regionID string) (*account.StsAccount, time.Time, error) { + stsEndpoint := fmt.Sprintf("sts.%s.aliyuncs.com", regionID) + + config := &openapi.Config{ + AccessKeyId: &accessKeyID, + AccessKeySecret: &accessKeySecret, + Endpoint: &stsEndpoint, + } + + stsClient, err := client.NewClient(config) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to initialize STS client: %w", err) + } + + request := &client.AssumeRoleRequest{ + RoleArn: &ramRole, + RoleSessionName: &roleSessionName, + DurationSeconds: &durationSeconds, + } + + res, err := stsClient.AssumeRole(request) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to assume role: %w", err) + } + + expiryTimeStamp := time.Now().Add(assumeRoleDefaultDuration) + return account.NewStsAccount(*res.Body.Credentials.AccessKeyId, *res.Body.Credentials.AccessKeySecret, *res.Body.Credentials.SecurityToken), expiryTimeStamp, nil +} diff --git a/pkg/stsClient/stsClient.go b/pkg/stsClient/stsClient.go deleted file mode 100644 index b8a2c2dc7..000000000 --- a/pkg/stsClient/stsClient.go +++ /dev/null @@ -1,95 +0,0 @@ -package sts - -import ( - "fmt" - "time" - - openapi "github.com/alibabacloud-go/darabonba-openapi/client" - openapiV2 "github.com/alibabacloud-go/darabonba-openapi/v2/client" - - "github.com/alibabacloud-go/sts-20150401/client" -) - -var assumeRoleDurationHours int64 = 1 - -type StsClient struct { - client *client.Client - expiryTimeStamp time.Time -} - -type Sts struct { - clients map[string]*StsClient -} - -func NewSTS() *Sts { - return &Sts{ - clients: make(map[string]*StsClient), - } -} - -func (s *Sts) IsSTSTokenValid(clientIdentifier string) bool { - client := s.clients[clientIdentifier] - if client == nil { - return false - } - - return time.Now().Before(client.expiryTimeStamp) -} - -func NewSTSClient(userAccessKeyID, userSecretAccessKey, regionID string) (*client.Client, error) { - stsEndpoint := fmt.Sprintf("sts.%s.aliyuncs.com", regionID) - - config := &openapi.Config{ - AccessKeyId: &userAccessKeyID, - AccessKeySecret: &userSecretAccessKey, - Endpoint: &stsEndpoint, - } - - stsClient, err := client.NewClient(config) - if err != nil { - return nil, fmt.Errorf("failed to initialize STS client: %w", err) - } - - return stsClient, nil -} - -func (s *Sts) GetSTSClient(clientIdentifier, userAccessKeyID, userSecret, regionID string) (*client.Client, error) { - if c, ok := s.clients[clientIdentifier]; ok && s.IsSTSTokenValid(clientIdentifier) { - return c.client, nil - } - - stsClient, err := NewSTSClient(userAccessKeyID, userSecret, regionID) - if err != nil { - return nil, err - } - - s.clients[clientIdentifier] = &StsClient{ - client: stsClient, - expiryTimeStamp: time.Now().Add(time.Duration(assumeRoleDurationHours) * time.Hour), - } - - return stsClient, nil -} - -func AssumeRole(stsClient *client.Client, roleArn, roleSessionName, regionID string) (*openapiV2.Config, error) { - durationSeconds := assumeRoleDurationHours * int64(time.Hour.Seconds()) - request := client.AssumeRoleRequest{ - RoleArn: &roleArn, - RoleSessionName: &roleSessionName, - DurationSeconds: &durationSeconds, - } - - res, err := stsClient.AssumeRole(&request) - if err != nil { - return nil, fmt.Errorf("failed to assume role: %w", err) - } - - config := &openapiV2.Config{ - AccessKeyId: res.Body.Credentials.AccessKeyId, - AccessKeySecret: res.Body.Credentials.AccessKeySecret, - SecurityToken: res.Body.Credentials.SecurityToken, - RegionId: ®ionID, - } - - return config, nil -} diff --git a/plugins/providers/maxcompute/provider.go b/plugins/providers/maxcompute/provider.go index 5f6c6c024..a27070d4f 100644 --- a/plugins/providers/maxcompute/provider.go +++ b/plugins/providers/maxcompute/provider.go @@ -10,18 +10,15 @@ import ( maxcompute "github.com/alibabacloud-go/maxcompute-20220104/client" "github.com/aliyun/aliyun-odps-go-sdk/odps" - "github.com/aliyun/aliyun-odps-go-sdk/odps/account" "github.com/aliyun/aliyun-odps-go-sdk/odps/restclient" "github.com/aliyun/aliyun-odps-go-sdk/odps/security" "github.com/bearaujus/bptr" pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" + aliauth "github.com/goto/guardian/pkg/aliauth" "github.com/goto/guardian/pkg/log" - sts "github.com/goto/guardian/pkg/stsClient" "github.com/goto/guardian/utils" "golang.org/x/net/context" - - openapiV2 "github.com/alibabacloud-go/darabonba-openapi/v2/client" ) //go:generate mockery --name=encryptor --exported --with-expecter @@ -30,13 +27,13 @@ type encryptor interface { } type ODPSClient struct { - client *odps.Odps - stsClientExist bool + client *odps.Odps + authConfig *aliauth.AliAuthConfig } type RestClient struct { - client *maxcompute.Client - stsClientExist bool + client *maxcompute.Client + authConfig *aliauth.AliAuthConfig } type provider struct { @@ -46,7 +43,6 @@ type provider struct { encryptor encryptor restClients map[string]RestClient odpsClients map[string]ODPSClient - sts *sts.Sts logger log.Logger mu sync.Mutex } @@ -61,7 +57,6 @@ func New( encryptor: encryptor, restClients: make(map[string]RestClient), odpsClients: make(map[string]ODPSClient), - sts: sts.NewSTS(), logger: logger, } @@ -393,125 +388,84 @@ func (p *provider) getRestClient(pc *domain.ProviderConfig) (*maxcompute.Client, return nil, err } - ramRole, stsClientID := p.getRamRoleAndStsClientID("rest", creds, "") - if restClient, ok := p.getCachedRestClient(ramRole, stsClientID, pc.URN); ok { - return restClient, nil - } + ramRole := p.getRamRole(creds, "") + cachedClientKey := fmt.Sprintf("%s:%s", creds.AccessKeyID, ramRole) - endpoint := fmt.Sprintf("maxcompute.%s.aliyuncs.com", creds.RegionID) - var clientConfig *openapiV2.Config - if creds.RAMRole != "" { - stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID) - if err != nil { - return nil, err + if c, exists := p.restClients[cachedClientKey]; exists { + if c.authConfig.IsConfigValid() { + return c.client, nil } + p.mu.Lock() + delete(p.restClients, cachedClientKey) + p.mu.Unlock() + } - clientConfig, err = sts.AssumeRole(stsClient, creds.RAMRole, pc.URN, creds.RegionID) - if err != nil { - return nil, err - } - clientConfig.Endpoint = &endpoint - } else { - clientConfig = &openapiV2.Config{ - AccessKeyId: &creds.AccessKeyID, - AccessKeySecret: &creds.AccessKeySecret, - Endpoint: &endpoint, - } + authCofig, err := aliauth.NewConfig(creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID, ramRole, pc.URN) + if err != nil { + return nil, err } - restClient, err := maxcompute.NewClient(clientConfig) + authCreds, err := authCofig.GetCredentials() if err != nil { return nil, err } - p.mu.Lock() - if creds.RAMRole != "" { - p.restClients[creds.RAMRole] = RestClient{client: restClient, stsClientExist: true} - } else { - p.restClients[pc.URN] = RestClient{client: restClient} + endpoint := fmt.Sprintf("maxcompute.%s.aliyuncs.com", creds.RegionID) + authCreds.Endpoint = &endpoint + restClient, err := maxcompute.NewClient(authCreds) + if err != nil { + return nil, err } - p.mu.Unlock() + + p.mu.Lock() + defer p.mu.Unlock() + p.restClients[cachedClientKey] = RestClient{client: restClient, authConfig: authCofig} + return restClient, nil } -func (p *provider) getOdpsClient(pc *domain.ProviderConfig, ramRoleFromAppeal string) (*odps.Odps, error) { +func (p *provider) getOdpsClient(pc *domain.ProviderConfig, overrideRamRole string) (*odps.Odps, error) { creds, err := p.getCreds(pc) if err != nil { return nil, err } - // getting client from memory cache - ramRole, stsClientID := p.getRamRoleAndStsClientID("odps", creds, ramRoleFromAppeal) - if odpsClient, ok := p.getCachedOdpsClient(ramRole, stsClientID, pc.URN); ok { - return odpsClient, nil - } + ramRole := p.getRamRole(creds, overrideRamRole) + cachedClientKey := fmt.Sprintf("%s:%s", creds.AccessKeyID, ramRole) - // initialize new client - var acc account.Account - if ramRole != "" { - stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID) - if err != nil { - return nil, err + if c, exists := p.odpsClients[cachedClientKey]; exists { + if c.authConfig.IsConfigValid() { + return c.client, nil } + p.mu.Lock() + delete(p.odpsClients, cachedClientKey) + p.mu.Unlock() + } - clientConfig, err := sts.AssumeRole(stsClient, ramRole, pc.URN, creds.RegionID) - if err != nil { - return nil, err - } - acc = account.NewStsAccount(*clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, *clientConfig.SecurityToken) - } else { - acc = account.NewAliyunAccount(creds.AccessKeyID, creds.AccessKeySecret) + authConfig, err := aliauth.NewConfig(creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID, ramRole, pc.URN) + if err != nil { + return nil, err } + endpoint := fmt.Sprintf("http://service.%s.maxcompute.aliyun.com/api", creds.RegionID) - client := odps.NewOdps(acc, endpoint) + client := odps.NewOdps(authConfig.GetAccount(), endpoint) p.mu.Lock() - if ramRoleFromAppeal != "" { - p.odpsClients[ramRoleFromAppeal] = ODPSClient{client: client, stsClientExist: true} - } else { - p.odpsClients[pc.URN] = ODPSClient{client: client} - } - p.mu.Unlock() + defer p.mu.Unlock() + p.odpsClients[cachedClientKey] = ODPSClient{client: client, authConfig: authConfig} return client, nil } -func (p *provider) getRamRoleAndStsClientID(clientType string, creds *credentials, ramRoleFromAppeal string) (string, string) { +func (p *provider) getRamRole(creds *credentials, overrideRamRole string) string { var ramRole string switch { - case ramRoleFromAppeal != "": - ramRole = ramRoleFromAppeal + case overrideRamRole != "": + ramRole = overrideRamRole case creds.RAMRole != "": ramRole = creds.RAMRole } - stsClientID := clientType + "-" + ramRole - return ramRole, stsClientID -} - -func (p *provider) getCachedOdpsClient(ramRole, stsClientID, urn string) (*odps.Odps, bool) { - c, ok := p.odpsClients[ramRole] - if ramRole != "" && ok && c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) { - return c.client, true - } - - if c, ok := p.odpsClients[urn]; ok { - return c.client, true - } - - return nil, false -} - -func (p *provider) getCachedRestClient(ramRole, stsClientID, urn string) (*maxcompute.Client, bool) { - c, ok := p.restClients[ramRole] - if ramRole != "" && ok && c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) { - return c.client, true - } - - if c, ok := p.restClients[urn]; ok { - return c.client, true - } - - return nil, false + return ramRole } func getParametersFromGrant[T any](g domain.Grant, key string) (T, bool, error) { diff --git a/plugins/providers/oss/provider.go b/plugins/providers/oss/provider.go index a60722c81..6d9f882b8 100644 --- a/plugins/providers/oss/provider.go +++ b/plugins/providers/oss/provider.go @@ -12,7 +12,7 @@ import ( pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" - sts "github.com/goto/guardian/pkg/stsClient" + aliauth "github.com/goto/guardian/pkg/aliauth" "github.com/goto/guardian/utils" "github.com/aliyun/aliyun-oss-go-sdk/oss" @@ -37,8 +37,8 @@ type Policy struct { } type OSSClient struct { - client *oss.Client - stsClientExist bool + client *oss.Client + authConfig aliauth.AliAuthConfig } type provider struct { pv.UnimplementedClient @@ -47,7 +47,6 @@ type provider struct { encryptor encryptor ossClients map[string]OSSClient - sts *sts.Sts mu sync.Mutex } @@ -57,7 +56,6 @@ func NewProvider(typeName string, encryptor encryptor) *provider { typeName: typeName, encryptor: encryptor, ossClients: make(map[string]OSSClient), - sts: sts.NewSTS(), } } @@ -484,67 +482,64 @@ func (p *provider) getCreds(pc *domain.ProviderConfig) (*Credentials, error) { return creds, nil } -func (p *provider) getOSSClient(pc *domain.ProviderConfig, ramRole string) (*oss.Client, error) { +func (p *provider) getOSSClient(pc *domain.ProviderConfig, overrideRamRole string) (*oss.Client, error) { creds, err := p.getCreds(pc) if err != nil { return nil, err } + ramRole := overrideRamRole if ramRole == "" { ramRole = creds.RAMRole } - stsClientID := "oss-" + ramRole - if ossClient, ok := p.getCachedOSSClient(ramRole, stsClientID, pc.URN); ok { - return ossClient, nil - } - - endpoint := fmt.Sprintf("https://oss-%s.aliyuncs.com", creds.RegionID) - var client *oss.Client - if ramRole != "" { - stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID) - if err != nil { - return nil, err - } + cachedClientKey := fmt.Sprintf("%s:%s", creds.AccessKeyID, ramRole) - clientConfig, err := sts.AssumeRole(stsClient, creds.RAMRole, pc.URN, creds.RegionID) - if err != nil { - return nil, err + // Check cache for existing client + if cachedClient, exists := p.ossClients[cachedClientKey]; exists { + if cachedClient.authConfig.IsConfigValid() { + return cachedClient.client, nil } + p.mu.Lock() + delete(p.ossClients, cachedClientKey) + p.mu.Unlock() + } - clientOpts := oss.SecurityToken(*clientConfig.SecurityToken) - client, err = oss.New(endpoint, *clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, clientOpts) - if err != nil { - return nil, fmt.Errorf("failed to initialize oss client: %w", err) - } - } else { - client, err = oss.New(endpoint, creds.AccessKeyID, creds.AccessKeySecret) - if err != nil { - return nil, fmt.Errorf("failed to initialize oss client: %w", err) - } + // Create new OSS client + client, authConfig, err := p.newOSSClient(creds, ramRole, pc.URN) + if err != nil { + return nil, err } + // Store in cache p.mu.Lock() - if ramRole != "" { - p.ossClients[ramRole] = OSSClient{client: client, stsClientExist: true} - } else { - p.ossClients[pc.URN] = OSSClient{client: client} - } - p.mu.Unlock() + defer p.mu.Unlock() + p.ossClients[cachedClientKey] = OSSClient{client: client, authConfig: authConfig} return client, nil } -func (p *provider) getCachedOSSClient(ramRole, stsClientID, urn string) (*oss.Client, bool) { - c, ok := p.ossClients[ramRole] - if ramRole != "" && ok && c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) { - return c.client, true +func (p *provider) newOSSClient(creds *Credentials, ramRole, urn string) (*oss.Client, aliauth.AliAuthConfig, error) { + endpoint := fmt.Sprintf("https://oss-%s.aliyuncs.com", creds.RegionID) + + authConfig, err := aliauth.NewConfig(creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID, ramRole, urn) + if err != nil { + return nil, aliauth.AliAuthConfig{}, fmt.Errorf("failed to create auth config: %w", err) + } + + authCreds, err := authConfig.GetCredentials() + if err != nil { + return nil, *authConfig, fmt.Errorf("failed to get credentials: %w", err) } - if c, ok := p.ossClients[urn]; ok { - return c.client, true + var ossClient *oss.Client + // Initialize OSS client with security token if available + if *authCreds.SecurityToken != "" { + ossClient, err = oss.New(endpoint, *authCreds.AccessKeyId, *authCreds.AccessKeySecret, oss.SecurityToken(*authCreds.SecurityToken)) + return ossClient, *authConfig, err } - return nil, false + ossClient, err = oss.New(endpoint, creds.AccessKeyID, creds.AccessKeySecret) + return ossClient, *authConfig, err } func getRAMRole(g domain.Grant) (string, error) { @@ -552,7 +547,7 @@ func getRAMRole(g domain.Grant) (string, error) { if err != nil { return "", err } - return fmt.Sprintf("acs:ram::%s:role/guardian.bot", resourceAccountID), nil + return fmt.Sprintf("acs:ram::%s:role/guardian-bot", resourceAccountID), nil } func getAccountIDFromResource(resource *domain.Resource) (string, error) {