Skip to content

Commit

Permalink
feat(stf/router): support backwards compat type URL in router (#21177)
Browse files Browse the repository at this point in the history
  • Loading branch information
testinginprod authored Aug 7, 2024
1 parent 90fd632 commit 4dc9469
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 18 deletions.
8 changes: 4 additions & 4 deletions server/v2/stf/stf.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ var Identity = []byte("stf")
type STF[T transaction.Tx] struct {
logger log.Logger

msgRouter Router
queryRouter Router
msgRouter coreRouterImpl
queryRouter coreRouterImpl

doPreBlock func(ctx context.Context, txs []T) error
doBeginBlock func(ctx context.Context) error
Expand Down Expand Up @@ -584,8 +584,8 @@ func newExecutionContext(
sender transaction.Identity,
state store.WriterMap,
execMode transaction.ExecMode,
msgRouter Router,
queryRouter Router,
msgRouter coreRouterImpl,
queryRouter coreRouterImpl,
) *executionContext {
meter := makeGasMeterFn(gas.NoGasLimit)
meteredState := makeGasMeteredStoreFn(meter, state)
Expand Down
63 changes: 49 additions & 14 deletions server/v2/stf/stf_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"

gogoproto "github.com/cosmos/gogoproto/proto"

Expand Down Expand Up @@ -61,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
return ok
}

func (b *MsgRouterBuilder) Build() (Router, error) {
func (b *MsgRouterBuilder) Build() (coreRouterImpl, error) {
handlers := make(map[string]appmodulev2.Handler)

globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error {
Expand Down Expand Up @@ -93,7 +94,7 @@ func (b *MsgRouterBuilder) Build() (Router, error) {
handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler)
}

return Router{
return coreRouterImpl{
handlers: handlers,
}, nil
}
Expand Down Expand Up @@ -139,39 +140,73 @@ func msgTypeURL(msg gogoproto.Message) string {
return gogoproto.MessageName(msg)
}

var _ router.Service = (*Router)(nil)
var _ router.Service = (*coreRouterImpl)(nil)

// Router implements the STF router for msg and query handlers.
type Router struct {
// coreRouterImpl implements the STF router for msg and query handlers.
type coreRouterImpl struct {
handlers map[string]appmodulev2.Handler
}

func (r Router) CanInvoke(_ context.Context, typeURL string) error {
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
// trimming prefixes is a backwards compatibility strategy that we use
// for baseapp components that did routing through type URL rather
// than protobuf message names.
typeURL = strings.TrimPrefix(typeURL, "/")
_, exists := r.handlers[typeURL]
if !exists {
return fmt.Errorf("%w: %s", ErrNoHandler, typeURL)
}
return nil
}

func (r Router) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
func (r coreRouterImpl) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
handlerResp, err := r.InvokeUntyped(ctx, req)
if err != nil {
return err
}
merge(handlerResp, resp)
return nil
}

func merge(src, dst gogoproto.Message) {
reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src)))
return merge(handlerResp, resp)
}

func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
func (r coreRouterImpl) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
typeName := msgTypeURL(req)
handler, exists := r.handlers[typeName]
if !exists {
return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName)
}
return handler(ctx, req)
}

// merge merges together two protobuf messages by setting the pointer
// to src in dst. Used internally.
func merge(src, dst gogoproto.Message) error {
if src == nil {
return fmt.Errorf("source message is nil")
}
if dst == nil {
return fmt.Errorf("destination message is nil")
}

srcVal := reflect.ValueOf(src)
dstVal := reflect.ValueOf(dst)

if srcVal.Kind() == reflect.Interface {
srcVal = srcVal.Elem()
}
if dstVal.Kind() == reflect.Interface {
dstVal = dstVal.Elem()
}

if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
return fmt.Errorf("both source and destination must be pointers")
}

srcElem := srcVal.Elem()
dstElem := dstVal.Elem()

if !srcElem.Type().AssignableTo(dstElem.Type()) {
return fmt.Errorf("incompatible types: cannot merge %v into %v", srcElem.Type(), dstElem.Type())
}

dstElem.Set(srcElem)
return nil
}
107 changes: 107 additions & 0 deletions server/v2/stf/stf_router_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package stf

import (
"context"
"testing"

gogoproto "github.com/cosmos/gogoproto/proto"
gogotypes "github.com/cosmos/gogoproto/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"cosmossdk.io/core/appmodule/v2"
)

func TestRouter(t *testing.T) {
expectedMsg := &gogotypes.BoolValue{Value: true}
expectedMsgName := gogoproto.MessageName(expectedMsg)

expectedResp := &gogotypes.StringValue{Value: "test"}

router := coreRouterImpl{handlers: map[string]appmodule.Handler{
gogoproto.MessageName(expectedMsg): func(ctx context.Context, gotMsg appmodule.Message) (msgResp appmodule.Message, err error) {
require.Equal(t, expectedMsg, gotMsg)
return expectedResp, nil
},
}}

t.Run("can invoke message by name", func(t *testing.T) {
err := router.CanInvoke(context.Background(), expectedMsgName)
require.NoError(t, err, "must be invokable")
})

t.Run("can invoke message by type URL", func(t *testing.T) {
err := router.CanInvoke(context.Background(), "/"+expectedMsgName)
require.NoError(t, err)
})

t.Run("cannot invoke unknown message", func(t *testing.T) {
err := router.CanInvoke(context.Background(), "not exist")
require.Error(t, err)
})

t.Run("invoke untyped", func(t *testing.T) {
gotResp, err := router.InvokeUntyped(context.Background(), expectedMsg)
require.NoError(t, err)
require.Equal(t, expectedResp, gotResp)
})

t.Run("invoked typed", func(t *testing.T) {
gotResp := new(gogotypes.StringValue)
err := router.InvokeTyped(context.Background(), expectedMsg, gotResp)
require.NoError(t, err)
require.Equal(t, expectedResp, gotResp)
})
}

func TestMerge(t *testing.T) {
tests := []struct {
name string
src gogoproto.Message
dst gogoproto.Message
expected gogoproto.Message
wantErr bool
}{
{
name: "success",
src: &gogotypes.BoolValue{Value: true},
dst: &gogotypes.BoolValue{},
expected: &gogotypes.BoolValue{Value: true},
wantErr: false,
},
{
name: "nil src",
src: nil,
dst: &gogotypes.StringValue{},
expected: &gogotypes.StringValue{},
wantErr: true,
},
{
name: "nil dst",
src: &gogotypes.StringValue{Value: "hello"},
dst: nil,
expected: nil,
wantErr: true,
},
{
name: "incompatible types",
src: &gogotypes.StringValue{Value: "hello"},
dst: &gogotypes.BoolValue{},
expected: &gogotypes.BoolValue{},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := merge(tt.src, tt.dst)

if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, tt.dst)
}
})
}
}

0 comments on commit 4dc9469

Please sign in to comment.