Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validation totalRate to check date overlapped budgets #83

Merged
Prev Previous commit
Next Next commit
fix: refactor CollectibleBudgets and add test cases
  • Loading branch information
dongsam committed Nov 5, 2021
commit 2fe6c6a567e9337dc8ffed5d495302fb50317ce1
14 changes: 1 addition & 13 deletions x/budget/keeper/budget.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func (k Keeper) CollectBudgets(ctx sdk.Context) error {
params := k.GetParams(ctx)
var budgets []types.Budget
if params.EpochBlocks > 0 && ctx.BlockHeight()%int64(params.EpochBlocks) == 0 {
budgets = k.CollectibleBudgets(ctx, params.Budgets)
budgets = types.CollectibleBudgets(params.Budgets, ctx.BlockTime())
}
if len(budgets) == 0 {
return nil
Expand Down Expand Up @@ -71,18 +71,6 @@ func (k Keeper) CollectBudgets(ctx sdk.Context) error {
return nil
}

// CollectibleBudgets returns scan through the budgets registered in params.Budgets
// and returns only the valid and not expired budgets.
func (k Keeper) CollectibleBudgets(ctx sdk.Context, budgets []types.Budget) (collectibleBudgets []types.Budget) {
for _, budget := range budgets {
err := budget.Validate()
if err == nil && budget.Collectible(ctx.BlockTime()) {
collectibleBudgets = append(collectibleBudgets, budget)
}
}
return
}

// GetTotalCollectedCoins returns total collected coins for a budget.
func (k Keeper) GetTotalCollectedCoins(ctx sdk.Context, budgetName string) sdk.Coins {
store := ctx.KVStore(k.storeKey)
Expand Down
2 changes: 1 addition & 1 deletion x/budget/keeper/budget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ func (suite *KeeperTestSuite) TestBudgetChangeSituation() {
height += 1
suite.ctx = suite.ctx.WithBlockHeight(int64(height))
suite.ctx = suite.ctx.WithBlockTime(tc.nextBlockTime)
budgets := suite.keeper.CollectibleBudgets(suite.ctx, params.Budgets)
budgets := suite.keeper.CollectibleBudgets(params.Budgets, suite.ctx.BlockTime())
suite.Require().Len(budgets, tc.collectibleBudgetCount)

// BeginBlocker
Expand Down
11 changes: 11 additions & 0 deletions x/budget/types/budget.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ func (budget Budget) Collectible(blockTime time.Time) bool {
return !budget.StartTime.After(blockTime) && budget.EndTime.After(blockTime)
}

// CollectibleBudgets returns only the valid and started and not expired budgets based on the given block time.
func CollectibleBudgets(budgets []Budget, blockTime time.Time) (collectibleBudgets []Budget) {
for _, budget := range budgets {
err := budget.Validate()
if err == nil && budget.Collectible(blockTime) {
collectibleBudgets = append(collectibleBudgets, budget)
}
}
return
}

// ValidateName is the default validation function for Budget.Name.
// A budget name only allows alphabet letters(`A-Z, a-z`), digit numbers(`0-9`), and `-`.
// It doesn't allow spaces and the maximum length is 50 characters.
Expand Down
79 changes: 49 additions & 30 deletions x/budget/types/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,12 @@ import (
"github.com/tendermint/budget/x/budget/types"
)

func TestParams(t *testing.T) {
require.IsType(t, paramstypes.KeyTable{}, types.ParamKeyTable())

defaultParams := types.DefaultParams()

paramsStr := `epoch_blocks: 1
budgets: []
`
require.Equal(t, paramsStr, defaultParams.String())
}

func TestValidateBudgets(t *testing.T) {
cAddr1 := sdk.AccAddress(address.Module(types.ModuleName, []byte("collectionAddr1")))
cAddr2 := sdk.AccAddress(address.Module(types.ModuleName, []byte("collectionAddr2")))
tAddr1 := sdk.AccAddress(address.Module(types.ModuleName, []byte("budgetSourceAddr1")))
tAddr2 := sdk.AccAddress(address.Module(types.ModuleName, []byte("budgetSourceAddr2")))
budgets := []types.Budget{
var (
cAddr1 = sdk.AccAddress(address.Module(types.ModuleName, []byte("collectionAddr1")))
cAddr2 = sdk.AccAddress(address.Module(types.ModuleName, []byte("collectionAddr2")))
tAddr1 = sdk.AccAddress(address.Module(types.ModuleName, []byte("budgetSourceAddr1")))
tAddr2 = sdk.AccAddress(address.Module(types.ModuleName, []byte("budgetSourceAddr2")))
budgets = []types.Budget{
{
Name: "test",
Rate: sdk.OneDec(),
Expand All @@ -38,28 +27,28 @@ func TestValidateBudgets(t *testing.T) {
EndTime: time.Time{},
},
{
Name: "test2",
Name: "test1",
Rate: sdk.OneDec(),
BudgetSourceAddress: tAddr2.String(),
CollectionAddress: cAddr2.String(),
StartTime: time.Time{},
EndTime: time.Time{},
StartTime: types.MustParseRFC3339("2021-07-01T00:00:00Z"),
EndTime: types.MustParseRFC3339("2021-07-10T00:00:00Z"),
},
{
Name: "test3",
Name: "test2",
Rate: sdk.MustNewDecFromStr("0.1"),
BudgetSourceAddress: tAddr2.String(),
CollectionAddress: cAddr2.String(),
StartTime: time.Time{},
EndTime: time.Time{},
StartTime: types.MustParseRFC3339("2021-07-01T00:00:00Z"),
EndTime: types.MustParseRFC3339("2021-07-10T00:00:00Z"),
},
{
Name: "test3",
Rate: sdk.MustNewDecFromStr("0.1"),
BudgetSourceAddress: tAddr2.String(),
CollectionAddress: cAddr2.String(),
StartTime: time.Time{},
EndTime: time.Time{},
StartTime: types.MustParseRFC3339("2021-08-01T00:00:00Z"),
EndTime: types.MustParseRFC3339("2021-08-10T00:00:00Z"),
},
{
Name: "test4",
Expand All @@ -78,23 +67,53 @@ func TestValidateBudgets(t *testing.T) {
EndTime: types.MustParseRFC3339("2021-08-25T00:00:00Z"),
},
}
)

func TestParams(t *testing.T) {
require.IsType(t, paramstypes.KeyTable{}, types.ParamKeyTable())

defaultParams := types.DefaultParams()

paramsStr := `epoch_blocks: 1
budgets: []
`
require.Equal(t, paramsStr, defaultParams.String())
}

err := types.ValidateBudgets(budgets[:2])
func TestValidateBudgets(t *testing.T) {
err := types.ValidateBudgets([]types.Budget{budgets[0], budgets[1]})
require.NoError(t, err)

err = types.ValidateBudgets(budgets[:3])
err = types.ValidateBudgets([]types.Budget{budgets[0], budgets[1], budgets[2]})
require.ErrorIs(t, err, types.ErrInvalidTotalBudgetRate)

err = types.ValidateBudgets(budgets[3:5])
err = types.ValidateBudgets([]types.Budget{budgets[1], budgets[4]})
require.NoError(t, err)

err = types.ValidateBudgets(budgets[4:6])
err = types.ValidateBudgets([]types.Budget{budgets[4], budgets[5]})
require.ErrorIs(t, err, types.ErrInvalidTotalBudgetRate)

err = types.ValidateBudgets(budgets)
err = types.ValidateBudgets([]types.Budget{budgets[3], budgets[3]})
require.ErrorIs(t, err, types.ErrDuplicateBudgetName)
}

func TestCollectibleBudgets(t *testing.T) {
collectibleBudgets := types.CollectibleBudgets([]types.Budget{budgets[0], budgets[1]}, types.MustParseRFC3339("2021-07-05T00:00:00Z"))
require.Len(t, collectibleBudgets, 1)

collectibleBudgets = types.CollectibleBudgets([]types.Budget{budgets[0], budgets[1], budgets[2]}, types.MustParseRFC3339("2021-07-05T00:00:00Z"))
require.Len(t, collectibleBudgets, 2)

collectibleBudgets = types.CollectibleBudgets([]types.Budget{budgets[4], budgets[5]}, types.MustParseRFC3339("2021-08-18T00:00:00Z"))
require.Len(t, collectibleBudgets, 1)

collectibleBudgets = types.CollectibleBudgets([]types.Budget{budgets[4], budgets[5]}, types.MustParseRFC3339("2021-08-19T00:00:00Z"))
require.Len(t, collectibleBudgets, 2)

collectibleBudgets = types.CollectibleBudgets([]types.Budget{budgets[4], budgets[5]}, types.MustParseRFC3339("2021-08-20T00:00:00Z"))
require.Len(t, collectibleBudgets, 1)
}

func TestValidateEpochBlocks(t *testing.T) {
err := types.ValidateEpochBlocks(uint32(0))
require.NoError(t, err)
Expand Down