From a1f28b09a7a99f656bc2e837cec450ddbc49ec59 Mon Sep 17 00:00:00 2001 From: Ooo0oO0o0oO <907709476@qq.com> Date: Mon, 27 Jan 2020 14:25:04 +0800 Subject: [PATCH] rename file --- common/extension/auth.go | 6 +++--- config/service_config.go | 8 +++++++- filter/{auth_ext.go => access_key.go} | 9 ++------- filter/authenticator.go | 11 +++++++++++ filter/filter_impl/auth/accesskey_storage.go | 4 ++-- filter/filter_impl/auth/accesskey_storage_test.go | 2 +- filter/filter_impl/auth/consumer_sign.go | 7 ++++--- filter/filter_impl/auth/consumer_sign_test.go | 4 ++-- .../{authenticator.go => default_authenticator.go} | 4 ++-- ...nticator_test.go => default_authenticator_test.go} | 0 filter/filter_impl/auth/provider_auth.go | 7 ++++--- filter/filter_impl/auth/provider_auth_test.go | 4 ++-- 12 files changed, 40 insertions(+), 26 deletions(-) rename filter/{auth_ext.go => access_key.go} (74%) create mode 100644 filter/authenticator.go rename filter/filter_impl/auth/{authenticator.go => default_authenticator.go} (98%) rename filter/filter_impl/auth/{authenticator_test.go => default_authenticator_test.go} (100%) diff --git a/common/extension/auth.go b/common/extension/auth.go index f3dbfbd0b5..e57e22f660 100644 --- a/common/extension/auth.go +++ b/common/extension/auth.go @@ -6,7 +6,7 @@ import ( var ( authenticators = make(map[string]func() filter.Authenticator) - accesskeyStorages = make(map[string]func() filter.AccesskeyStorage) + accesskeyStorages = make(map[string]func() filter.AccessKeyStorage) ) func SetAuthenticator(name string, fcn func() filter.Authenticator) { @@ -20,11 +20,11 @@ func GetAuthenticator(name string) filter.Authenticator { return authenticators[name]() } -func SetAccesskeyStorages(name string, fcn func() filter.AccesskeyStorage) { +func SetAccesskeyStorages(name string, fcn func() filter.AccessKeyStorage) { accesskeyStorages[name] = fcn } -func GetAccesskeyStorages(name string) filter.AccesskeyStorage { +func GetAccesskeyStorages(name string) filter.AccessKeyStorage { if accesskeyStorages[name] == nil { panic("accesskeyStorages for " + name + " is not existing, make sure you have import the package.") } diff --git a/config/service_config.go b/config/service_config.go index 37ec3a3ae6..72074d6072 100644 --- a/config/service_config.go +++ b/config/service_config.go @@ -67,6 +67,8 @@ type ServiceConfig struct { TpsLimitRejectedHandler string `yaml:"tps.limit.rejected.handler" json:"tps.limit.rejected.handler,omitempty" property:"tps.limit.rejected.handler"` ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"` ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"` + Auth string `yaml:"auth" json:"auth,omitempty" property:"auth"` + ParamSign string `yaml:"param.sign" json:"param.sign,omitempty" property:"param.sign"` unexported *atomic.Bool exported *atomic.Bool @@ -220,7 +222,11 @@ func (c *ServiceConfig) getUrlMap() url.Values { urlMap.Set(constant.EXECUTE_LIMIT_KEY, c.ExecuteLimit) urlMap.Set(constant.EXECUTE_REJECTED_EXECUTION_HANDLER_KEY, c.ExecuteLimitRejectedHandler) - for _, v := range c.Methods { + // auth filter + urlMap.Set(constant.SERVICE_AUTH_KEY, srvconfig.Auth) + urlMap.Set(constant.PARAMTER_SIGNATURE_ENABLE_KEY, srvconfig.ParamSign) + + for _, v := range srvconfig.Methods { prefix := "methods." + v.Name + "." urlMap.Set(prefix+constant.LOADBALANCE_KEY, v.Loadbalance) urlMap.Set(prefix+constant.RETRIES_KEY, v.Retries) diff --git a/filter/auth_ext.go b/filter/access_key.go similarity index 74% rename from filter/auth_ext.go rename to filter/access_key.go index 1497dfa034..36b709daf2 100644 --- a/filter/auth_ext.go +++ b/filter/access_key.go @@ -14,11 +14,6 @@ type AccessKeyPair struct { Options string `yaml:"options" json:"options,omitempty" property:"options"` } -type Authenticator interface { - Sign(protocol.Invocation, *common.URL) error - Authenticate(protocol.Invocation, *common.URL) error -} - -type AccesskeyStorage interface { - GetAccesskeyPair(protocol.Invocation, *common.URL) *AccessKeyPair +type AccessKeyStorage interface { + GetAccessKeyPair(protocol.Invocation, *common.URL) *AccessKeyPair } diff --git a/filter/authenticator.go b/filter/authenticator.go new file mode 100644 index 0000000000..b7def1cc98 --- /dev/null +++ b/filter/authenticator.go @@ -0,0 +1,11 @@ +package filter + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" +) + +type Authenticator interface { + Sign(protocol.Invocation, *common.URL) error + Authenticate(protocol.Invocation, *common.URL) error +} diff --git a/filter/filter_impl/auth/accesskey_storage.go b/filter/filter_impl/auth/accesskey_storage.go index c957cae6dc..924266a0d6 100644 --- a/filter/filter_impl/auth/accesskey_storage.go +++ b/filter/filter_impl/auth/accesskey_storage.go @@ -11,7 +11,7 @@ import ( type DefaultAccesskeyStorage struct { } -func (storage *DefaultAccesskeyStorage) GetAccesskeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair { +func (storage *DefaultAccesskeyStorage) GetAccessKeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair { return &filter.AccessKeyPair{ AccessKey: url.GetParam(constant.ACCESS_KEY_ID_KEY, ""), SecretKey: url.GetParam(constant.SECRET_ACCESS_KEY_KEY, ""), @@ -22,6 +22,6 @@ func init() { extension.SetAccesskeyStorages(constant.DEFAULT_ACCESS_KEY_STORAGE, GetDefaultAccesskeyStorage) } -func GetDefaultAccesskeyStorage() filter.AccesskeyStorage { +func GetDefaultAccesskeyStorage() filter.AccessKeyStorage { return &DefaultAccesskeyStorage{} } diff --git a/filter/filter_impl/auth/accesskey_storage_test.go b/filter/filter_impl/auth/accesskey_storage_test.go index 15d2ce5b84..6ab861a867 100644 --- a/filter/filter_impl/auth/accesskey_storage_test.go +++ b/filter/filter_impl/auth/accesskey_storage_test.go @@ -22,7 +22,7 @@ func TestDefaultAccesskeyStorage_GetAccesskeyPair(t *testing.T) { common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey")) invocation := &invocation2.RPCInvocation{} storage := GetDefaultAccesskeyStorage() - accesskeyPair := storage.GetAccesskeyPair(invocation, url) + accesskeyPair := storage.GetAccessKeyPair(invocation, url) assert.Equal(t, "skey", accesskeyPair.SecretKey) assert.Equal(t, "akey", accesskeyPair.AccessKey) } diff --git a/filter/filter_impl/auth/consumer_sign.go b/filter/filter_impl/auth/consumer_sign.go index 948706b04d..5feb289bf3 100644 --- a/filter/filter_impl/auth/consumer_sign.go +++ b/filter/filter_impl/auth/consumer_sign.go @@ -1,6 +1,7 @@ package auth import ( + "context" "fmt" ) import ( @@ -18,7 +19,7 @@ func init() { extension.SetFilter(constant.CONSUMER_SIGN_FILTER, getConsumerSignFilter) } -func (csf *ConsumerSignFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (csf *ConsumerSignFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking ConsumerSign filter.") url := invoker.GetUrl() @@ -29,10 +30,10 @@ func (csf *ConsumerSignFilter) Invoke(invoker protocol.Invoker, invocation proto panic(fmt.Sprintf("Sign for invocation %s # %s failed", url.ServiceKey(), invocation.MethodName())) } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (csf *ConsumerSignFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (csf *ConsumerSignFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } func getConsumerSignFilter() filter.Filter { diff --git a/filter/filter_impl/auth/consumer_sign_test.go b/filter/filter_impl/auth/consumer_sign_test.go index 23d6ba1b4e..c90a769bcb 100644 --- a/filter/filter_impl/auth/consumer_sign_test.go +++ b/filter/filter_impl/auth/consumer_sign_test.go @@ -30,8 +30,8 @@ func TestConsumerSignFilter_Invoke(t *testing.T) { result := &protocol.RPCResult{} invoker.EXPECT().Invoke(inv).Return(result).Times(2) invoker.EXPECT().GetUrl().Return(url).Times(2) - assert.Equal(t, result, filter.Invoke(invoker, inv)) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) url.SetParam(constant.SERVICE_AUTH_KEY, "true") - assert.Equal(t, result, filter.Invoke(invoker, inv)) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) } diff --git a/filter/filter_impl/auth/authenticator.go b/filter/filter_impl/auth/default_authenticator.go similarity index 98% rename from filter/filter_impl/auth/authenticator.go rename to filter/filter_impl/auth/default_authenticator.go index 2b14ac1a8f..0f0172f696 100644 --- a/filter/filter_impl/auth/authenticator.go +++ b/filter/filter_impl/auth/default_authenticator.go @@ -3,6 +3,7 @@ package auth import ( "errors" "fmt" + "github.com/apache/dubbo-go/filter" "strconv" "time" ) @@ -11,7 +12,6 @@ import ( "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" - "github.com/apache/dubbo-go/filter" "github.com/apache/dubbo-go/protocol" invocation_impl "github.com/apache/dubbo-go/protocol/invocation" ) @@ -89,7 +89,7 @@ func (authenticator *DefaultAuthenticator) Authenticate(invocation protocol.Invo func getAccessKeyPair(invocation protocol.Invocation, url *common.URL) (*filter.AccessKeyPair, error) { accesskeyStorage := extension.GetAccesskeyStorages(url.GetParam(constant.ACCESS_KEY_STORAGE_KEY, constant.DEFAULT_ACCESS_KEY_STORAGE)) - accessKeyPair := accesskeyStorage.GetAccesskeyPair(invocation, url) + accessKeyPair := accesskeyStorage.GetAccessKeyPair(invocation, url) if accessKeyPair == nil || IsEmpty(accessKeyPair.AccessKey, false) || IsEmpty(accessKeyPair.SecretKey, true) { return nil, errors.New("accessKeyId or secretAccessKey not found") } else { diff --git a/filter/filter_impl/auth/authenticator_test.go b/filter/filter_impl/auth/default_authenticator_test.go similarity index 100% rename from filter/filter_impl/auth/authenticator_test.go rename to filter/filter_impl/auth/default_authenticator_test.go diff --git a/filter/filter_impl/auth/provider_auth.go b/filter/filter_impl/auth/provider_auth.go index a7f872503a..c13f342534 100644 --- a/filter/filter_impl/auth/provider_auth.go +++ b/filter/filter_impl/auth/provider_auth.go @@ -1,6 +1,7 @@ package auth import ( + "context" "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" "github.com/apache/dubbo-go/common/logger" @@ -15,7 +16,7 @@ func init() { extension.SetFilter(constant.PROVIDER_AUTH_FILTER, getProviderAuthFilter) } -func (paf *ProviderAuthFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (paf *ProviderAuthFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking providerAuth filter.") url := invoker.GetUrl() @@ -29,10 +30,10 @@ func (paf *ProviderAuthFilter) Invoke(invoker protocol.Invoker, invocation proto } } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (paf *ProviderAuthFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (paf *ProviderAuthFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } func getProviderAuthFilter() filter.Filter { diff --git a/filter/filter_impl/auth/provider_auth_test.go b/filter/filter_impl/auth/provider_auth_test.go index 56c3f6121b..7552a4aa04 100644 --- a/filter/filter_impl/auth/provider_auth_test.go +++ b/filter/filter_impl/auth/provider_auth_test.go @@ -50,8 +50,8 @@ func TestProviderAuthFilter_Invoke(t *testing.T) { result := &protocol.RPCResult{} invoker.EXPECT().Invoke(inv).Return(result).Times(2) invoker.EXPECT().GetUrl().Return(url).Times(2) - assert.Equal(t, result, filter.Invoke(invoker, inv)) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) url.SetParam(constant.SERVICE_AUTH_KEY, "true") - assert.Equal(t, result, filter.Invoke(invoker, inv)) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) }