Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: add connection to transport context #4649

Merged
merged 7 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
}
done := make(chan struct{})
t := &http2Server{
ctx: context.Background(),
ctx: setConnection(context.Background(), conn),
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
Expand Down Expand Up @@ -1345,3 +1345,18 @@ func getJitter(v time.Duration) time.Duration {
j := grpcrand.Int63n(2*r) - r
return time.Duration(j)
}

type connectionKey struct{}

// GetConnection gets the connection from the context.
func GetConnection(ctx context.Context) net.Conn {
conn, _ := ctx.Value(connectionKey{}).(net.Conn)
return conn
}

// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}
16 changes: 3 additions & 13 deletions internal/xds/rbac/rbac_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

var getConnection = transport.GetConnection

// ChainEngine represents a chain of RBAC Engines, used to make authorization
// decisions on incoming RPCs.
type ChainEngine struct {
Expand Down Expand Up @@ -206,16 +209,3 @@ type rpcData struct {
// handshake.
certs []*x509.Certificate
}

type connectionKey struct{}

func getConnection(ctx context.Context) net.Conn {
conn, _ := ctx.Value(connectionKey{}).(net.Conn)
return conn
}

// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC.
func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}
7 changes: 6 additions & 1 deletion internal/xds/rbac/rbac_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,9 @@ func (s) TestNewChainEngine(t *testing.T) {
// different types of data representing incoming RPC's (piped into a context),
// and verifies that it works as expected.
func (s) TestChainEngine(t *testing.T) {
defer func(gc func(ctx context.Context) net.Conn) {
getConnection = gc
}(getConnection)
tests := []struct {
name string
rbacConfigs []*v3rbacpb.RBAC
Expand Down Expand Up @@ -882,7 +885,9 @@ func (s) TestChainEngine(t *testing.T) {
}
conn := <-connCh
defer conn.Close()
ctx = SetConnection(ctx, conn)
getConnection = func(context.Context) net.Conn {
return conn
}
ctx = peer.NewContext(ctx, data.rpcData.peerInfo)
stream := &ServerTransportStreamWithMethod{
method: data.rpcData.fullMethod,
Expand Down
54 changes: 54 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7766,3 +7766,57 @@ func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials {
func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error {
return nil
}

func unaryInterceptorVerifyConn(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
conn := transport.GetConnection(ctx)
if conn == nil {
return nil, status.Error(codes.NotFound, "connection was not in context")
}
return nil, status.Error(codes.OK, "")
}

// TestUnaryServerInterceptorGetsConnection tests whether the accepted conn on
// the server gets to any unary interceptors on the server side.
func (s) TestUnaryServerInterceptorGetsConnection(t *testing.T) {
ss := &stubserver.StubServer{}
if err := ss.Start([]grpc.ServerOption{grpc.UnaryInterceptor(unaryInterceptorVerifyConn)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v, want _, error code %s", err, codes.OK)
}
}

func streamingInterceptorVerifyConn(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
conn := transport.GetConnection(ss.Context())
if conn == nil {
return status.Error(codes.NotFound, "connection was not in context")
}
return status.Error(codes.OK, "")
}

// TestStreamingServerInterceptorGetsConnection tests whether the accepted conn on
// the server gets to any streaming interceptors on the server side.
func (s) TestStreamingServerInterceptorGetsConnection(t *testing.T) {
ss := &stubserver.StubServer{}
if err := ss.Start([]grpc.ServerOption{grpc.StreamInterceptor(streamingInterceptorVerifyConn)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

s, err := ss.Client.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{})
if err != nil {
t.Fatalf("ss.Client.StreamingOutputCall(_) = _, %v, want _, <nil>", err)
}
if _, err := s.Recv(); err != io.EOF {
t.Fatalf("ss.Client.StreamingInputCall(_) = _, %v, want _, %v", err, io.EOF)
}
}