From fd2035d4ce66d3994e154b707c8d8fb0a01ffaa1 Mon Sep 17 00:00:00 2001 From: Sean McGary Date: Wed, 5 Feb 2025 12:44:10 -0600 Subject: [PATCH] fix: normalize addresses used as inputs from rpcs --- pkg/rpcServer/rewardsHandlers.go | 16 +++++-- pkg/service/rewardsDataService/rewards.go | 42 ++++++++++++------- .../rewardsDataService/rewards_test.go | 4 ++ 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/pkg/rpcServer/rewardsHandlers.go b/pkg/rpcServer/rewardsHandlers.go index 5e86a592..2f697084 100644 --- a/pkg/rpcServer/rewardsHandlers.go +++ b/pkg/rpcServer/rewardsHandlers.go @@ -314,6 +314,13 @@ func (rpc *RpcServer) GetAvailableRewardsTokens(ctx context.Context, req *reward }, nil } +func withDefaultValue(value string, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} + func (rpc *RpcServer) GetSummarizedRewardsForEarner(ctx context.Context, req *rewardsV1.GetSummarizedRewardsForEarnerRequest) (*rewardsV1.GetSummarizedRewardsForEarnerResponse, error) { earner := req.GetEarnerAddress() blockHeight := req.GetBlockHeight() @@ -329,12 +336,13 @@ func (rpc *RpcServer) GetSummarizedRewardsForEarner(ctx context.Context, req *re return &rewardsV1.GetSummarizedRewardsForEarnerResponse{ Rewards: utils.Map(summarizedRewards, func(r *rewardsDataService.SummarizedReward, i uint64) *rewardsV1.SummarizedEarnerReward { + return &rewardsV1.SummarizedEarnerReward{ Token: r.Token, - Earned: r.Earned, - Active: r.Active, - Claimed: r.Claimed, - Claimable: r.Claimable, + Earned: withDefaultValue(r.Earned, "0"), + Active: withDefaultValue(r.Active, "0"), + Claimed: withDefaultValue(r.Claimed, "0"), + Claimable: withDefaultValue(r.Claimable, "0"), } }), }, nil diff --git a/pkg/service/rewardsDataService/rewards.go b/pkg/service/rewardsDataService/rewards.go index c03e0b9c..9dcde1db 100644 --- a/pkg/service/rewardsDataService/rewards.go +++ b/pkg/service/rewardsDataService/rewards.go @@ -56,17 +56,29 @@ type TotalClaimedReward struct { Amount string } +func lowercaseTokenList(tokens []string) []string { + return utils.Map(tokens, func(token string, i uint64) string { + return strings.ToLower(token) + }) +} + func (rds *RewardsDataService) GetTotalClaimedRewards(ctx context.Context, earner string, tokens []string, blockHeight uint64) ([]*TotalClaimedReward, error) { blockHeight, err := rds.BaseDataService.GetCurrentBlockHeightIfNotPresent(ctx, blockHeight) if err != nil { return nil, err } + if earner == "" { + return nil, fmt.Errorf("earner is required") + } + earner = strings.ToLower(earner) + tokens = lowercaseTokenList(tokens) + query := ` select earner, token, - sum(claimed_amount) as amount + coalesce(sum(claimed_amount), 0) as amount from rewards_claimed as rc where earner = @earner @@ -134,14 +146,12 @@ func (rds *RewardsDataService) ListClaimedRewardsByBlockRange( } if earner != "" { query += " and earner = @earner" - args = append(args, sql.Named("earner", earner)) + args = append(args, sql.Named("earner", strings.ToLower(earner))) } if len(tokens) > 0 { query += " and token in (?)" - formattedTokens := utils.Map(tokens, func(token string, i uint64) string { - return strings.ToLower(token) - }) - args = append(args, sql.Named("tokens", formattedTokens)) + tokens = lowercaseTokenList(tokens) + args = append(args, sql.Named("tokens", tokens)) } query += " order by block_number, log_index" @@ -170,6 +180,7 @@ func (rds *RewardsDataService) GetTotalRewardsForEarner( if earner == "" { return nil, fmt.Errorf("earner is required") } + earner = strings.ToLower(earner) snapshot, err := rds.findDistributionRootClosestToBlockHeight(blockHeight, claimable) if err != nil { @@ -192,7 +203,7 @@ func (rds *RewardsDataService) GetTotalRewardsForEarner( ) select token, - sum(amount) as amount + coalesce(sum(amount), 0) as amount from token_snapshots group by 1 ` @@ -202,10 +213,8 @@ func (rds *RewardsDataService) GetTotalRewardsForEarner( } if len(tokens) > 0 { query += " and token in (?)" - formattedTokens := utils.Map(tokens, func(token string, i uint64) string { - return strings.ToLower(token) - }) - args = append(args, sql.Named("tokens", formattedTokens)) + tokens = lowercaseTokenList(tokens) + args = append(args, sql.Named("tokens", tokens)) } rewardAmounts := make([]*RewardAmount, 0) @@ -232,6 +241,8 @@ func (rds *RewardsDataService) GetClaimableRewardsForEarner( if earner == "" { return nil, nil, fmt.Errorf("earner is required") } + earner = strings.ToLower(earner) + snapshot, err := rds.findDistributionRootClosestToBlockHeight(blockHeight, true) if err != nil { return nil, nil, err @@ -281,10 +292,8 @@ func (rds *RewardsDataService) GetClaimableRewardsForEarner( } if len(tokens) > 0 { query += " and token in (?)" - formattedTokens := utils.Map(tokens, func(token string, i uint64) string { - return strings.ToLower(token) - }) - args = append(args, sql.Named("tokens", formattedTokens)) + tokens = lowercaseTokenList(tokens) + args = append(args, sql.Named("tokens", tokens)) } claimableRewards := make([]*RewardAmount, 0) @@ -372,6 +381,8 @@ func (rds *RewardsDataService) GetSummarizedRewards(ctx context.Context, earner if earner == "" { return nil, fmt.Errorf("earner is required") } + earner = strings.ToLower(earner) + tokens = lowercaseTokenList(tokens) blockHeight, err := rds.BaseDataService.GetCurrentBlockHeightIfNotPresent(context.Background(), blockHeight) if err != nil { @@ -487,6 +498,7 @@ func (rds *RewardsDataService) ListAvailableRewardsTokens(ctx context.Context, e if earner == "" { return nil, fmt.Errorf("earner is required") } + earner = strings.ToLower(earner) blockHeight, err := rds.BaseDataService.GetCurrentBlockHeightIfNotPresent(ctx, blockHeight) if err != nil { diff --git a/pkg/service/rewardsDataService/rewards_test.go b/pkg/service/rewardsDataService/rewards_test.go index d0fca77b..e0a2274b 100644 --- a/pkg/service/rewardsDataService/rewards_test.go +++ b/pkg/service/rewardsDataService/rewards_test.go @@ -125,6 +125,10 @@ func Test_RewardsDataService(t *testing.T) { r, err := rds.GetSummarizedRewards(context.Background(), earner, nil, blockNumber) assert.Nil(t, err) assert.NotNil(t, r) + fmt.Printf("Summarized rewards: %+v\n", r) + for _, sr := range r { + fmt.Printf(" %+v\n", sr) + } }) t.Run("Test ListAvailableRewardsTokens", func(t *testing.T) {