From 12530dd7889210aac6c07f2f46f47635fcdd08b4 Mon Sep 17 00:00:00 2001
From: Ivan <2103732+codebien@users.noreply.github.com>
Date: Tue, 15 Mar 2022 09:22:37 +0100
Subject: [PATCH] lib/netext/grpcext: Split the gRPC module
It splits the gRPC business logic from the js/k6/grpc module to a netext
package.
It will be easier to extend the networking+metric features for gRPC in
isolated environment and/or create extension not strictly depending on
the js feature.
---
js/modules/k6/grpc/client.go | 427 ++++--------------------------
js/modules/k6/grpc/client_test.go | 130 +--------
js/modules/k6/grpc/grpc.go | 20 --
lib/netext/grpcext/conn.go | 305 +++++++++++++++++++++
lib/netext/grpcext/conn_test.go | 300 +++++++++++++++++++++
lib/netext/grpcext/reflect.go | 117 ++++++++
6 files changed, 776 insertions(+), 523 deletions(-)
create mode 100644 lib/netext/grpcext/conn.go
create mode 100644 lib/netext/grpcext/conn_test.go
create mode 100644 lib/netext/grpcext/reflect.go
diff --git a/js/modules/k6/grpc/client.go b/js/modules/k6/grpc/client.go
index 4d975730a6a2..c9e520acdcc4 100644
--- a/js/modules/k6/grpc/client.go
+++ b/js/modules/k6/grpc/client.go
@@ -1,94 +1,37 @@
-/*
- *
- * k6 - a next-generation load testing tool
- * Copyright (C) 2020 Load Impact
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as
- * published by the Free Software Foundation, either version 3 of the
- * License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- *
- */
-
package grpc
import (
"context"
- "encoding/json"
"errors"
"fmt"
"io"
- "net"
- "strconv"
"strings"
"time"
+ "go.k6.io/k6/js/common"
+ "go.k6.io/k6/js/modules"
+ "go.k6.io/k6/lib/netext/grpcext"
+ "go.k6.io/k6/lib/types"
+ "go.k6.io/k6/metrics"
+
"github.com/dop251/goja"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
- "github.com/sirupsen/logrus"
"google.golang.org/grpc"
- "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
- grpcstats "google.golang.org/grpc/stats"
- "google.golang.org/grpc/status"
- "google.golang.org/protobuf/encoding/protojson"
- "google.golang.org/protobuf/encoding/prototext"
- "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
- "google.golang.org/protobuf/types/dynamicpb"
-
- //nolint: staticcheck
- protoV1 "github.com/golang/protobuf/proto"
-
- "go.k6.io/k6/js/common"
- "go.k6.io/k6/js/modules"
- "go.k6.io/k6/lib/types"
- "go.k6.io/k6/metrics"
- reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
-)
-
-//nolint: lll
-var (
- errInvokeRPCInInitContext = common.NewInitContextError("invoking RPC methods in the init context is not supported")
- errConnectInInitContext = common.NewInitContextError("connecting to a gRPC server in the init context is not supported")
)
// Client represents a gRPC client that can be used to make RPC requests
type Client struct {
mds map[string]protoreflect.MethodDescriptor
- conn *grpc.ClientConn
-
- vu modules.VU
-}
-
-// MethodInfo holds information on any parsed method descriptors that can be used by the goja VM
-type MethodInfo struct {
- grpc.MethodInfo `json:"-" js:"-"`
- Package string
- Service string
- FullMethod string
-}
-
-// Response is a gRPC response that can be used by the goja VM
-type Response struct {
- Status codes.Code
- Message interface{}
- Headers map[string][]string
- Trailers map[string][]string
- Error interface{}
+ conn *grpcext.Conn
+ vu modules.VU
+ addr string
}
// Load will parse the given proto files and make the file descriptors available to request.
@@ -134,7 +77,7 @@ func (c *Client) Load(importPaths []string, filenames ...string) ([]MethodInfo,
func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) {
state := c.vu.State()
if state == nil {
- return false, errConnectInInitContext
+ return false, common.NewInitContextError("connecting to a gRPC server in the init context is not supported")
}
p, err := c.parseConnectParams(params)
@@ -142,31 +85,29 @@ func (c *Client) Connect(addr string, params map[string]interface{}) (bool, erro
return false, err
}
- opts := make([]grpc.DialOption, 0, 2)
+ opts := grpcext.DefaultOptions(c.vu)
+ var tcred credentials.TransportCredentials
if !p.IsPlaintext {
tlsCfg := state.TLSConfig.Clone()
tlsCfg.NextProtos = []string{"h2"}
// TODO(rogchap): Would be good to add support for custom RootCAs (self signed)
-
- opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)))
+ tcred = credentials.NewTLS(tlsCfg)
} else {
- opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+ tcred = insecure.NewCredentials()
}
+ opts = append(opts, grpc.WithTransportCredentials(tcred))
if ua := state.Options.UserAgent; ua.Valid {
opts = append(opts, grpc.WithUserAgent(ua.ValueOrZero()))
}
- dialer := func(ctx context.Context, addr string) (net.Conn, error) {
- return state.Dialer.DialContext(ctx, "tcp", addr)
- }
- opts = append(opts, grpc.WithContextDialer(dialer))
-
ctx, cancel := context.WithTimeout(c.vu.Context(), p.Timeout)
defer cancel()
+ c.addr = addr
+
err = c.dial(ctx, addr, p.UseReflectionProtocol, opts...)
return err != nil, err
}
@@ -176,11 +117,10 @@ func (c *Client) Invoke(
method string,
req goja.Value,
params map[string]interface{},
-) (*Response, error) {
- rt := c.vu.Runtime()
+) (*grpcext.Response, error) {
state := c.vu.State()
if state == nil {
- return nil, errInvokeRPCInInitContext
+ return nil, common.NewInitContextError("invoking RPC methods in the init context is not supported")
}
if c.conn == nil {
return nil, errors.New("no gRPC connection, you must call connect first")
@@ -191,8 +131,8 @@ func (c *Client) Invoke(
if method[0] != '/' {
method = "/" + method
}
- md := c.mds[method]
- if md == nil {
+ methodDesc := c.mds[method]
+ if methodDesc == nil {
return nil, fmt.Errorf("method %q not found in file descriptors", method)
}
@@ -201,18 +141,26 @@ func (c *Client) Invoke(
return nil, err
}
- ctx := metadata.NewOutgoingContext(c.vu.Context(), metadata.New(nil))
+ b, err := req.ToObject(c.vu.Runtime()).MarshalJSON()
+ if err != nil {
+ return nil, fmt.Errorf("unable to serialise request object: %w", err)
+ }
+
+ md := metadata.New(nil)
for param, strval := range p.Metadata {
- ctx = metadata.AppendToOutgoingContext(ctx, param, strval)
+ md.Append(param, strval)
}
+ ctx, cancel := context.WithTimeout(c.vu.Context(), p.Timeout)
+ defer cancel()
+
tags := state.CloneTags()
for k, v := range p.Tags {
tags[k] = v
}
if state.Options.SystemTags.Has(metrics.TagURL) {
- tags["url"] = fmt.Sprintf("%s%s", c.conn.Target(), method)
+ tags["url"] = fmt.Sprintf("%s%s", c.addr, method)
}
parts := strings.Split(method[1:], "/")
if state.Options.SystemTags.Has(metrics.TagService) {
@@ -227,74 +175,18 @@ func (c *Client) Invoke(
tags["name"] = method
}
- ctx = withTags(ctx, tags)
-
- reqdm := dynamicpb.NewMessage(md.Input())
- {
- b, err := req.ToObject(rt).MarshalJSON()
- if err != nil {
- return nil, fmt.Errorf("unable to serialise request object: %w", err)
- }
- if err := protojson.Unmarshal(b, reqdm); err != nil {
- return nil, fmt.Errorf("unable to serialise request object to protocol buffer: %w", err)
- }
+ reqmsg := grpcext.Request{
+ MethodDescriptor: methodDesc,
+ Message: b,
+ Tags: tags,
}
- reqCtx, cancel := context.WithTimeout(ctx, p.Timeout)
- defer cancel()
-
- resp := dynamicpb.NewMessage(md.Output())
- header, trailer := metadata.New(nil), metadata.New(nil)
- err = c.conn.Invoke(reqCtx, method, reqdm, resp, grpc.Header(&header), grpc.Trailer(&trailer))
-
- var response Response
- response.Headers = header
- response.Trailers = trailer
-
- marshaler := protojson.MarshalOptions{EmitUnpopulated: true}
-
- if err != nil {
- sterr := status.Convert(err)
- response.Status = sterr.Code()
-
- // (rogchap) when you access a JSON property in goja, you are actually accessing the underling
- // Go type (struct, map, slice etc); because these are dynamic messages the Unmarshaled JSON does
- // not map back to a "real" field or value (as a normal Go type would). If we don't marshal and then
- // unmarshal back to a map, you will get "undefined" when accessing JSON properties, even when
- // JSON.Stringify() shows the object to be correctly present.
-
- raw, _ := marshaler.Marshal(sterr.Proto())
- errMsg := make(map[string]interface{})
- _ = json.Unmarshal(raw, &errMsg)
- response.Error = errMsg
- }
-
- if resp != nil {
- // (rogchap) there is a lot of marshaling/unmarshaling here, but if we just pass the dynamic message
- // the default Marshaller would be used, which would strip any zero/default values from the JSON.
- // eg. given this message:
- // message Point {
- // double x = 1;
- // double y = 2;
- // double z = 3;
- // }
- // and a value like this:
- // msg := Point{X: 6, Y: 4, Z: 0}
- // would result in JSON output:
- // {"x":6,"y":4}
- // rather than the desired:
- // {"x":6,"y":4,"z":0}
- raw, _ := marshaler.Marshal(resp)
- msg := make(map[string]interface{})
- _ = json.Unmarshal(raw, &msg)
- response.Message = msg
- }
- return &response, nil
+ return c.conn.Invoke(ctx, method, md, reqmsg)
}
// Close will close the client gRPC connection
func (c *Client) Close() error {
- if c == nil || c.conn == nil {
+ if c.conn == nil {
return nil
}
err := c.conn.Close()
@@ -303,6 +195,14 @@ func (c *Client) Close() error {
return err
}
+// MethodInfo holds information on any parsed method descriptors that can be used by the goja VM
+type MethodInfo struct {
+ Package string
+ Service string
+ FullMethod string
+ grpc.MethodInfo `json:"-" js:"-"`
+}
+
func (c *Client) convertToMethodInfo(fdset *descriptorpb.FileDescriptorSet) ([]MethodInfo, error) {
files, err := protodesc.NewFiles(fdset)
if err != nil {
@@ -353,16 +253,8 @@ func (c *Client) dial(
reflect bool,
options ...grpc.DialOption,
) error {
- opts := []grpc.DialOption{
- grpc.WithBlock(),
- grpc.FailOnNonTempDialError(true),
- grpc.WithStatsHandler(statsHandler{vu: c.vu}),
- grpc.WithReturnConnectionError(),
- }
- opts = append(opts, options...)
-
var err error
- c.conn, err = grpc.DialContext(ctx, addr, opts...)
+ c.conn, err = grpcext.Dial(ctx, addr, options...)
if err != nil {
return err
}
@@ -370,32 +262,13 @@ func (c *Client) dial(
if !reflect {
return nil
}
-
- return c.reflect(ctx)
-}
-
-// reflect will use the grpc reflection api to make the file descriptors available to request.
-// It is called in the connect function the first time the Client.Connect function is called.
-func (c *Client) reflect(ctx context.Context) error {
- client := reflectpb.NewServerReflectionClient(c.conn)
- methodClient, err := client.ServerReflectionInfo(ctx)
+ rc, err := c.conn.ReflectionClient()
if err != nil {
- return fmt.Errorf("can't get server info: %w", err)
- }
- req := &reflectpb.ServerReflectionRequest{
- MessageRequest: &reflectpb.ServerReflectionRequest_ListServices{},
- }
- resp, err := sendReceive(methodClient, req)
- if err != nil {
- return fmt.Errorf("can't list services: %w", err)
- }
- listResp := resp.GetListServicesResponse()
- if listResp == nil {
- return fmt.Errorf("can't list services, nil response")
+ return err
}
- fdset, err := resolveServiceFileDescriptors(methodClient, listResp)
+ fdset, err := rc.Reflect(ctx)
if err != nil {
- return fmt.Errorf("can't resolve services' file descriptors: %w", err)
+ return err
}
_, err = c.convertToMethodInfo(fdset)
if err != nil {
@@ -501,140 +374,6 @@ func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams,
return params, nil
}
-type statsHandler struct {
- vu modules.VU
-}
-
-// TagConn implements the grpcstats.Handler interface
-func (statsHandler) TagConn(ctx context.Context, _ *grpcstats.ConnTagInfo) context.Context {
- // noop
- return ctx
-}
-
-// HandleConn implements the grpcstats.Handler interface
-func (statsHandler) HandleConn(context.Context, grpcstats.ConnStats) {
- // noop
-}
-
-// TagRPC implements the grpcstats.Handler interface
-func (statsHandler) TagRPC(ctx context.Context, _ *grpcstats.RPCTagInfo) context.Context {
- // noop
- return ctx
-}
-
-// HandleRPC implements the grpcstats.Handler interface
-func (h statsHandler) HandleRPC(ctx context.Context, stat grpcstats.RPCStats) {
- state := h.vu.State()
- tags := getTags(ctx)
- switch s := stat.(type) {
- case *grpcstats.OutHeader:
- if state.Options.SystemTags.Has(metrics.TagIP) && s.RemoteAddr != nil {
- if ip, _, err := net.SplitHostPort(s.RemoteAddr.String()); err == nil {
- tags["ip"] = ip
- }
- }
- case *grpcstats.End:
- if state.Options.SystemTags.Has(metrics.TagStatus) {
- tags["status"] = strconv.Itoa(int(status.Code(s.Error)))
- }
-
- mTags := map[string]string(tags)
- sampleTags := metrics.IntoSampleTags(&mTags)
- metrics.PushIfNotDone(ctx, state.Samples, metrics.ConnectedSamples{
- Samples: []metrics.Sample{
- {
- Metric: state.BuiltinMetrics.GRPCReqDuration,
- Tags: sampleTags,
- Value: metrics.D(s.EndTime.Sub(s.BeginTime)),
- Time: s.EndTime,
- },
- },
- })
- }
-
- // (rogchap) Re-using --http-debug flag as gRPC is technically still HTTP
- if state.Options.HTTPDebug.String != "" {
- logger := state.Logger.WithField("source", "http-debug")
- httpDebugOption := state.Options.HTTPDebug.String
- debugStat(stat, logger, httpDebugOption)
- }
-}
-
-// sendReceiver is a smaller interface for decoupling
-// from `reflectpb.ServerReflection_ServerReflectionInfoClient`,
-// that has the dependency from `grpc.ClientStream`,
-// which is too much in the case the requirement is to just make a reflection's request.
-// It makes the API more restricted and with a controlled surface,
-// in this way the testing should be easier also.
-type sendReceiver interface {
- Send(*reflectpb.ServerReflectionRequest) error
- Recv() (*reflectpb.ServerReflectionResponse, error)
-}
-
-// sendReceive sends a request to a reflection client and,
-// receives a response.
-func sendReceive(
- client sendReceiver,
- req *reflectpb.ServerReflectionRequest,
-) (*reflectpb.ServerReflectionResponse, error) {
- if err := client.Send(req); err != nil {
- return nil, fmt.Errorf("can't send request: %w", err)
- }
- resp, err := client.Recv()
- if err != nil {
- return nil, fmt.Errorf("can't receive response: %w", err)
- }
- return resp, nil
-}
-
-type fileDescriptorLookupKey struct {
- Package string
- Name string
-}
-
-func resolveServiceFileDescriptors(
- client sendReceiver,
- res *reflectpb.ListServiceResponse,
-) (*descriptorpb.FileDescriptorSet, error) {
- services := res.GetService()
- seen := make(map[fileDescriptorLookupKey]bool, len(services))
- fdset := &descriptorpb.FileDescriptorSet{
- File: make([]*descriptorpb.FileDescriptorProto, 0, len(services)),
- }
-
- for _, service := range services {
- req := &reflectpb.ServerReflectionRequest{
- MessageRequest: &reflectpb.ServerReflectionRequest_FileContainingSymbol{
- FileContainingSymbol: service.GetName(),
- },
- }
- resp, err := sendReceive(client, req)
- if err != nil {
- return nil, fmt.Errorf("can't get method on service %q: %w", service, err)
- }
- fdResp := resp.GetFileDescriptorResponse()
- for _, raw := range fdResp.GetFileDescriptorProto() {
- var fdp descriptorpb.FileDescriptorProto
- if err = proto.Unmarshal(raw, &fdp); err != nil {
- return nil, fmt.Errorf("can't unmarshal proto on service %q: %w", service, err)
- }
- fdkey := fileDescriptorLookupKey{
- Package: *fdp.Package,
- Name: *fdp.Name,
- }
- if seen[fdkey] {
- // When a proto file contains declarations for multiple services
- // then the same proto file is returned multiple times,
- // this prevents adding the returned proto file as a duplicate.
- continue
- }
- seen[fdkey] = true
- fdset.File = append(fdset.File, &fdp)
- }
- }
- return fdset, nil
-}
-
func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto {
fds := []*descriptorpb.FileDescriptorProto{}
@@ -651,67 +390,3 @@ func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*d
return fds
}
-
-func debugStat(stat grpcstats.RPCStats, logger logrus.FieldLogger, httpDebugOption string) {
- switch s := stat.(type) {
- case *grpcstats.OutHeader:
- logger.Infof("Out Header:\nFull Method: %s\nRemote Address: %s\n%s\n",
- s.FullMethod, s.RemoteAddr, formatMetadata(s.Header))
- case *grpcstats.OutTrailer:
- if len(s.Trailer) > 0 {
- logger.Infof("Out Trailer:\n%s\n", formatMetadata(s.Trailer))
- }
- case *grpcstats.OutPayload:
- if httpDebugOption == "full" {
- logger.Infof("Out Payload:\nWire Length: %d\nSent Time: %s\n%s\n\n",
- s.WireLength, s.SentTime, formatPayload(s.Payload))
- }
- case *grpcstats.InHeader:
- if len(s.Header) > 0 {
- logger.Infof("In Header:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Header))
- }
- case *grpcstats.InTrailer:
- if len(s.Trailer) > 0 {
- logger.Infof("In Trailer:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Trailer))
- }
- case *grpcstats.InPayload:
- if httpDebugOption == "full" {
- logger.Infof("In Payload:\nWire Length: %d\nReceived Time: %s\n%s\n\n",
- s.WireLength, s.RecvTime, formatPayload(s.Payload))
- }
- }
-}
-
-func formatMetadata(md metadata.MD) string {
- var sb strings.Builder
- for k, v := range md {
- sb.WriteString(k)
- sb.WriteString(": ")
- sb.WriteString(strings.Join(v, ", "))
- sb.WriteRune('\n')
- }
-
- return sb.String()
-}
-
-func formatPayload(payload interface{}) string {
- msg, ok := payload.(proto.Message)
- if !ok {
- // check to see if we are dealing with a APIv1 message
- msgV1, ok := payload.(protoV1.Message)
- if !ok {
- return ""
- }
- msg = protoV1.MessageV2(msgV1)
- }
-
- marshaler := prototext.MarshalOptions{
- Multiline: true,
- Indent: " ",
- }
- b, err := marshaler.Marshal(msg)
- if err != nil {
- return ""
- }
- return string(b)
-}
diff --git a/js/modules/k6/grpc/client_test.go b/js/modules/k6/grpc/client_test.go
index d3dfe8ad9850..22dabe124645 100644
--- a/js/modules/k6/grpc/client_test.go
+++ b/js/modules/k6/grpc/client_test.go
@@ -1,23 +1,3 @@
-/*
- *
- * k6 - a next-generation load testing tool
- * Copyright (C) 2020 Load Impact
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as
- * published by the Free Software Foundation, either version 3 of the
- * License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- *
- */
-
package grpc
import (
@@ -25,19 +5,14 @@ import (
"context"
"crypto/tls"
"errors"
- "fmt"
"io/ioutil"
"net/url"
"os"
"runtime"
"strings"
- "sync/atomic"
"testing"
"google.golang.org/grpc/reflection"
- reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/types/descriptorpb"
"github.com/dop251/goja"
"github.com/sirupsen/logrus"
@@ -56,6 +31,7 @@ import (
"go.k6.io/k6/js/modulestest"
"go.k6.io/k6/lib"
"go.k6.io/k6/lib/fsext"
+ "go.k6.io/k6/lib/netext/grpcext"
"go.k6.io/k6/lib/testutils"
"go.k6.io/k6/lib/testutils/httpmultibin"
"go.k6.io/k6/metrics"
@@ -388,7 +364,7 @@ func TestClient(t *testing.T) {
client.connect("GRPCBIN_ADDR");
var resp = client.invoke("grpc.testing.TestService/EmptyCall", {})
if (resp.status !== grpc.StatusOK) {
- throw new Error("unexpected error status: " + resp.status)
+ throw new Error("unexpected error: " + JSON.stringify(resp.error) + "or status: " + resp.status)
}`,
asserts: func(t *testing.T, rb *httpmultibin.HTTPMultiBin, samples chan metrics.SampleContainer, _ error) {
samplesBuf := metrics.GetBufferedSamples(samples)
@@ -788,7 +764,7 @@ func TestDebugStat(t *testing.T) {
logger := logrus.New()
logger.Out = &b
- debugStat(tt.stat, logger.WithField("source", "test"), "full")
+ grpcext.DebugStat(logger.WithField("source", "test"), tt.stat, "full")
assert.Contains(t, b.String(), tt.expected)
})
}
@@ -823,103 +799,3 @@ func TestClientInvokeHeadersDeprecated(t *testing.T) {
require.Len(t, entries, 1)
require.Contains(t, entries[0].Message, "headers property is deprecated")
}
-
-func TestResolveFileDescriptors(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- pkgs []string
- services []string
- expectedDescriptors int
- }{
- {
- name: "SuccessSamePackage",
- pkgs: []string{"mypkg"},
- services: []string{"Service1", "Service2", "Service3"},
- expectedDescriptors: 3,
- },
- {
- name: "SuccessMultiPackages",
- pkgs: []string{"mypkg1", "mypkg2", "mypkg3"},
- services: []string{"Service", "Service", "Service"},
- expectedDescriptors: 3,
- },
- {
- name: "DeduplicateServices",
- pkgs: []string{"mypkg1"},
- services: []string{"Service1", "Service2", "Service1"},
- expectedDescriptors: 2,
- },
- {
- name: "NoServices",
- services: []string{},
- expectedDescriptors: 0,
- },
- }
-
- for _, tt := range tests {
- tt := tt
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- var (
- lsr = &reflectpb.ListServiceResponse{}
- mock = &getServiceFileDescriptorMock{}
- )
- for i, service := range tt.services {
- // if only one package is defined then
- // the package is the same for every service
- pkg := tt.pkgs[0]
- if len(tt.pkgs) > 1 {
- pkg = tt.pkgs[i]
- }
-
- lsr.Service = append(lsr.Service, &reflectpb.ServiceResponse{
- Name: fmt.Sprintf("%s.%s", pkg, service),
- })
- mock.pkgs = append(mock.pkgs, pkg)
- mock.names = append(mock.names, service)
- }
-
- fdset, err := resolveServiceFileDescriptors(mock, lsr)
- require.NoError(t, err)
- assert.Len(t, fdset.File, tt.expectedDescriptors)
- })
- }
-}
-
-type getServiceFileDescriptorMock struct {
- nreqs int64
- pkgs []string
- names []string
-}
-
-func (m *getServiceFileDescriptorMock) Send(req *reflectpb.ServerReflectionRequest) error {
- // TODO: check that the sent message is expected,
- // otherwise return an error
- return nil
-}
-
-func (m *getServiceFileDescriptorMock) Recv() (*reflectpb.ServerReflectionResponse, error) {
- n := atomic.AddInt64(&m.nreqs, 1)
- ptr := func(s string) (sptr *string) {
- return &s
- }
- index := n - 1
- fdp := &descriptorpb.FileDescriptorProto{
- Package: ptr(m.pkgs[index]),
- Name: ptr(m.names[index]),
- }
- b, err := proto.Marshal(fdp)
- if err != nil {
- return nil, err
- }
- srr := &reflectpb.ServerReflectionResponse{
- MessageResponse: &reflectpb.ServerReflectionResponse_FileDescriptorResponse{
- FileDescriptorResponse: &reflectpb.FileDescriptorResponse{
- FileDescriptorProto: [][]byte{b},
- },
- },
- }
- return srr, nil
-}
diff --git a/js/modules/k6/grpc/grpc.go b/js/modules/k6/grpc/grpc.go
index cd0babb7cc6c..46ab0c923b6e 100644
--- a/js/modules/k6/grpc/grpc.go
+++ b/js/modules/k6/grpc/grpc.go
@@ -1,23 +1,3 @@
-/*
- *
- * k6 - a next-generation load testing tool
- * Copyright (C) 2020 Load Impact
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as
- * published by the Free Software Foundation, either version 3 of the
- * License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- *
- */
-
package grpc
import (
diff --git a/lib/netext/grpcext/conn.go b/lib/netext/grpcext/conn.go
new file mode 100644
index 000000000000..976b3c6394f1
--- /dev/null
+++ b/lib/netext/grpcext/conn.go
@@ -0,0 +1,305 @@
+// Package grpcext allows gRPC requests collecting stats info.
+package grpcext
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "go.k6.io/k6/js/modules"
+ "go.k6.io/k6/metrics"
+
+ protov1 "github.com/golang/protobuf/proto" //nolint:staticcheck,nolintlint // this is the old v1 version
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ grpcstats "google.golang.org/grpc/stats"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/encoding/protojson"
+ "google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/types/dynamicpb"
+)
+
+// Request represents a gRPC request.
+type Request struct {
+ MethodDescriptor protoreflect.MethodDescriptor
+ Tags map[string]string
+ Message []byte
+}
+
+// Response represents a gRPC response.
+type Response struct {
+ Message interface{}
+ Error interface{}
+ Headers map[string][]string
+ Trailers map[string][]string
+ Status codes.Code
+}
+
+type clientConnCloser interface {
+ grpc.ClientConnInterface
+ Close() error
+}
+
+// Conn is a gRPC client connection.
+type Conn struct {
+ raw clientConnCloser
+}
+
+// DefaultOptions generates an option set
+// with common options for requests from a VU.
+func DefaultOptions(vu modules.VU) []grpc.DialOption {
+ dialer := func(ctx context.Context, addr string) (net.Conn, error) {
+ return vu.State().Dialer.DialContext(ctx, "tcp", addr)
+ }
+
+ return []grpc.DialOption{
+ grpc.WithBlock(),
+ grpc.FailOnNonTempDialError(true),
+ grpc.WithReturnConnectionError(),
+ grpc.WithStatsHandler(statsHandler{vu: vu}),
+ grpc.WithContextDialer(dialer),
+ }
+}
+
+// Dial establish a gRPC connection.
+func Dial(ctx context.Context, addr string, options ...grpc.DialOption) (*Conn, error) {
+ conn, err := grpc.DialContext(ctx, addr, options...)
+ if err != nil {
+ return nil, err
+ }
+ return &Conn{
+ raw: conn,
+ }, nil
+}
+
+// ReflectionClient returns a reflection client based on the current connection.
+func (c *Conn) ReflectionClient() (*ReflectionClient, error) {
+ return &ReflectionClient{Conn: c.raw}, nil
+}
+
+// Invoke executes a unary gRPC request.
+//
+// TODO: should we support grpc.CallOption?
+func (c *Conn) Invoke(ctx context.Context, method string, md metadata.MD, req Request) (*Response, error) {
+ if method == "" {
+ return nil, fmt.Errorf("method is required")
+ }
+ if req.MethodDescriptor == nil {
+ return nil, fmt.Errorf("request method descriptor is required")
+ }
+ if len(req.Message) == 0 {
+ return nil, fmt.Errorf("request message is required")
+ }
+
+ ctx = metadata.NewOutgoingContext(ctx, md)
+
+ reqdm := dynamicpb.NewMessage(req.MethodDescriptor.Input())
+ if err := protojson.Unmarshal(req.Message, reqdm); err != nil {
+ return nil, fmt.Errorf("unable to serialise request object to protocol buffer: %w", err)
+ }
+
+ ctx = withTags(ctx, req.Tags)
+
+ resp := dynamicpb.NewMessage(req.MethodDescriptor.Output())
+ header, trailer := metadata.New(nil), metadata.New(nil)
+ err := c.raw.Invoke(ctx, method, reqdm, resp, grpc.Header(&header), grpc.Trailer(&trailer))
+
+ response := Response{
+ Headers: header,
+ Trailers: trailer,
+ }
+
+ marshaler := protojson.MarshalOptions{EmitUnpopulated: true}
+
+ if err != nil {
+ sterr := status.Convert(err)
+ response.Status = sterr.Code()
+
+ // (rogchap) when you access a JSON property in goja, you are actually accessing the underling
+ // Go type (struct, map, slice etc); because these are dynamic messages the Unmarshaled JSON does
+ // not map back to a "real" field or value (as a normal Go type would). If we don't marshal and then
+ // unmarshal back to a map, you will get "undefined" when accessing JSON properties, even when
+ // JSON.Stringify() shows the object to be correctly present.
+
+ raw, _ := marshaler.Marshal(sterr.Proto())
+ errMsg := make(map[string]interface{})
+ _ = json.Unmarshal(raw, &errMsg)
+ response.Error = errMsg
+ }
+
+ if resp != nil {
+ // (rogchap) there is a lot of marshaling/unmarshaling here, but if we just pass the dynamic message
+ // the default Marshaller would be used, which would strip any zero/default values from the JSON.
+ // eg. given this message:
+ // message Point {
+ // double x = 1;
+ // double y = 2;
+ // double z = 3;
+ // }
+ // and a value like this:
+ // msg := Point{X: 6, Y: 4, Z: 0}
+ // would result in JSON output:
+ // {"x":6,"y":4}
+ // rather than the desired:
+ // {"x":6,"y":4,"z":0}
+ raw, _ := marshaler.Marshal(resp)
+ msg := make(map[string]interface{})
+ _ = json.Unmarshal(raw, &msg)
+ response.Message = msg
+ }
+ return &response, nil
+}
+
+// Close closes the underhood connection.
+func (c *Conn) Close() error {
+ return c.raw.Close()
+}
+
+type statsHandler struct {
+ vu modules.VU
+}
+
+// TagConn implements the grpcstats.Handler interface
+func (statsHandler) TagConn(ctx context.Context, _ *grpcstats.ConnTagInfo) context.Context { // noop
+ return ctx
+}
+
+// HandleConn implements the grpcstats.Handler interface
+func (statsHandler) HandleConn(context.Context, grpcstats.ConnStats) {
+ // noop
+}
+
+// TagRPC implements the grpcstats.Handler interface
+func (statsHandler) TagRPC(ctx context.Context, _ *grpcstats.RPCTagInfo) context.Context {
+ // noop
+ return ctx
+}
+
+// HandleRPC implements the grpcstats.Handler interface
+func (h statsHandler) HandleRPC(ctx context.Context, stat grpcstats.RPCStats) {
+ state := h.vu.State()
+ tags := getTags(ctx)
+ switch s := stat.(type) {
+ case *grpcstats.OutHeader:
+ if state.Options.SystemTags.Has(metrics.TagIP) && s.RemoteAddr != nil {
+ if ip, _, err := net.SplitHostPort(s.RemoteAddr.String()); err == nil {
+ tags["ip"] = ip
+ }
+ }
+ case *grpcstats.End:
+ if state.Options.SystemTags.Has(metrics.TagStatus) {
+ tags["status"] = strconv.Itoa(int(status.Code(s.Error)))
+ }
+
+ mTags := map[string]string(tags)
+ sampleTags := metrics.IntoSampleTags(&mTags)
+ metrics.PushIfNotDone(ctx, state.Samples, metrics.ConnectedSamples{
+ Samples: []metrics.Sample{
+ {
+ Metric: state.BuiltinMetrics.GRPCReqDuration,
+ Tags: sampleTags,
+ Value: metrics.D(s.EndTime.Sub(s.BeginTime)),
+ Time: s.EndTime,
+ },
+ },
+ })
+ }
+
+ // (rogchap) Re-using --http-debug flag as gRPC is technically still HTTP
+ if state.Options.HTTPDebug.String != "" {
+ logger := state.Logger.WithField("source", "http-debug")
+ httpDebugOption := state.Options.HTTPDebug.String
+ DebugStat(logger, stat, httpDebugOption)
+ }
+}
+
+// DebugStat prints debugging information based on RPCStats.
+func DebugStat(logger logrus.FieldLogger, stat grpcstats.RPCStats, httpDebugOption string) {
+ switch s := stat.(type) {
+ case *grpcstats.OutHeader:
+ logger.Infof("Out Header:\nFull Method: %s\nRemote Address: %s\n%s\n",
+ s.FullMethod, s.RemoteAddr, formatMetadata(s.Header))
+ case *grpcstats.OutTrailer:
+ if len(s.Trailer) > 0 {
+ logger.Infof("Out Trailer:\n%s\n", formatMetadata(s.Trailer))
+ }
+ case *grpcstats.OutPayload:
+ if httpDebugOption == "full" {
+ logger.Infof("Out Payload:\nWire Length: %d\nSent Time: %s\n%s\n\n",
+ s.WireLength, s.SentTime, formatPayload(s.Payload))
+ }
+ case *grpcstats.InHeader:
+ if len(s.Header) > 0 {
+ logger.Infof("In Header:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Header))
+ }
+ case *grpcstats.InTrailer:
+ if len(s.Trailer) > 0 {
+ logger.Infof("In Trailer:\nWire Length: %d\n%s\n", s.WireLength, formatMetadata(s.Trailer))
+ }
+ case *grpcstats.InPayload:
+ if httpDebugOption == "full" {
+ logger.Infof("In Payload:\nWire Length: %d\nReceived Time: %s\n%s\n\n",
+ s.WireLength, s.RecvTime, formatPayload(s.Payload))
+ }
+ }
+}
+
+func formatMetadata(md metadata.MD) string {
+ var sb strings.Builder
+ for k, v := range md {
+ sb.WriteString(k)
+ sb.WriteString(": ")
+ sb.WriteString(strings.Join(v, ", "))
+ sb.WriteRune('\n')
+ }
+
+ return sb.String()
+}
+
+func formatPayload(payload interface{}) string {
+ msg, ok := payload.(proto.Message)
+ if !ok {
+ // check to see if we are dealing with a APIv1 message
+ msgV1, ok := payload.(protov1.Message)
+ if !ok {
+ return ""
+ }
+ msg = protov1.MessageV2(msgV1)
+ }
+
+ marshaler := prototext.MarshalOptions{
+ Multiline: true,
+ Indent: " ",
+ }
+ b, err := marshaler.Marshal(msg)
+ if err != nil {
+ return ""
+ }
+ return string(b)
+}
+
+type ctxKeyTags struct{}
+
+type reqtags map[string]string
+
+func withTags(ctx context.Context, tags reqtags) context.Context {
+ if tags == nil {
+ tags = make(map[string]string)
+ }
+ return context.WithValue(ctx, ctxKeyTags{}, tags)
+}
+
+func getTags(ctx context.Context) reqtags {
+ v := ctx.Value(ctxKeyTags{})
+ if v == nil {
+ return make(map[string]string)
+ }
+ return v.(reqtags)
+}
diff --git a/lib/netext/grpcext/conn_test.go b/lib/netext/grpcext/conn_test.go
new file mode 100644
index 000000000000..93985ee8fcff
--- /dev/null
+++ b/lib/netext/grpcext/conn_test.go
@@ -0,0 +1,300 @@
+package grpcext
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "sync/atomic"
+ "testing"
+
+ "github.com/jhump/protoreflect/desc/protoparse"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
+ "google.golang.org/protobuf/encoding/protojson"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protodesc"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/types/descriptorpb"
+ "google.golang.org/protobuf/types/dynamicpb"
+)
+
+func TestInvoke(t *testing.T) {
+ t.Parallel()
+
+ helloReply := func(in, out *dynamicpb.Message) error {
+ err := protojson.Unmarshal([]byte(`{"reply":"text reply"}`), out)
+ require.NoError(t, err)
+
+ return nil
+ }
+
+ c := Conn{raw: invokemock(helloReply)}
+ r := Request{
+ MethodDescriptor: methodFromProto("./testdata/hello.proto", "SayHello"),
+ Message: []byte(`{"greeting":"text request"}`),
+ }
+ res, err := c.Invoke(context.Background(), "/hello.HelloService/SayHello", metadata.New(nil), r)
+ require.NoError(t, err)
+
+ assert.Equal(t, codes.OK, res.Status)
+ assert.Equal(t, map[string]interface{}{"reply": "text reply"}, res.Message)
+ assert.Empty(t, res.Error)
+}
+
+func TestInvokeReturnError(t *testing.T) {
+ t.Parallel()
+
+ helloReply := func(in, out *dynamicpb.Message) error {
+ return fmt.Errorf("test error")
+ }
+
+ c := Conn{raw: invokemock(helloReply)}
+ r := Request{
+ MethodDescriptor: methodFromProto("./testdata/hello.proto", "SayHello"),
+ Message: []byte(`{"greeting":"text request"}`),
+ }
+ res, err := c.Invoke(context.Background(), "/hello.HelloService/SayHello", metadata.New(nil), r)
+ require.NoError(t, err)
+
+ assert.Equal(t, codes.Unknown, res.Status)
+ assert.NotEmpty(t, res.Error)
+ assert.Equal(t, map[string]interface{}{"reply": ""}, res.Message)
+}
+
+func TestConnInvokeInvalid(t *testing.T) {
+ t.Parallel()
+
+ var (
+ // valid arguments
+ ctx = context.Background()
+ method = "not-empty-method"
+ md = metadata.New(nil)
+ methodDesc = methodFromProto("./testdata/hello.proto", "SayHello")
+ payload = []byte(`{"greeting":"test"}`)
+ )
+
+ req := Request{
+ MethodDescriptor: methodDesc,
+ Message: payload,
+ }
+
+ tests := []struct {
+ name string
+ ctx context.Context
+ md metadata.MD
+ method string
+ req Request
+ experr string
+ }{
+ {
+ name: "EmptyMethod",
+ ctx: ctx,
+ method: "",
+ md: md,
+ req: req,
+ experr: "method is required",
+ },
+ {
+ name: "NullMethodDescriptor",
+ ctx: ctx,
+ method: method,
+ md: nil,
+ req: Request{Message: payload},
+ experr: "method descriptor is required",
+ },
+ {
+ name: "NullMessage",
+ ctx: ctx,
+ method: method,
+ md: nil,
+ req: Request{MethodDescriptor: methodDesc},
+ experr: "message is required",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := Conn{}
+ res, err := c.Invoke(tt.ctx, tt.method, tt.md, tt.req)
+ require.Error(t, err)
+ require.Nil(t, res)
+ assert.Contains(t, err.Error(), tt.experr)
+ })
+ }
+}
+
+func TestResolveFileDescriptors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ pkgs []string
+ services []string
+ expectedDescriptors int
+ }{
+ {
+ name: "SuccessSamePackage",
+ pkgs: []string{"mypkg"},
+ services: []string{"Service1", "Service2", "Service3"},
+ expectedDescriptors: 3,
+ },
+ {
+ name: "SuccessMultiPackages",
+ pkgs: []string{"mypkg1", "mypkg2", "mypkg3"},
+ services: []string{"Service", "Service", "Service"},
+ expectedDescriptors: 3,
+ },
+ {
+ name: "DeduplicateServices",
+ pkgs: []string{"mypkg1"},
+ services: []string{"Service1", "Service2", "Service1"},
+ expectedDescriptors: 2,
+ },
+ {
+ name: "NoServices",
+ services: []string{},
+ expectedDescriptors: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ var (
+ lsr = &reflectpb.ListServiceResponse{}
+ mock = &getServiceFileDescriptorMock{}
+ )
+ for i, service := range tt.services {
+ // if only one package is defined then
+ // the package is the same for every service
+ pkg := tt.pkgs[0]
+ if len(tt.pkgs) > 1 {
+ pkg = tt.pkgs[i]
+ }
+
+ lsr.Service = append(lsr.Service, &reflectpb.ServiceResponse{
+ Name: fmt.Sprintf("%s.%s", pkg, service),
+ })
+ mock.pkgs = append(mock.pkgs, pkg)
+ mock.names = append(mock.names, service)
+ }
+
+ rc := ReflectionClient{}
+ fdset, err := rc.resolveServiceFileDescriptors(mock, lsr)
+ require.NoError(t, err)
+ assert.Len(t, fdset.File, tt.expectedDescriptors)
+ })
+ }
+}
+
+type getServiceFileDescriptorMock struct {
+ pkgs []string
+ names []string
+ nreqs int64
+}
+
+func (m *getServiceFileDescriptorMock) Send(req *reflectpb.ServerReflectionRequest) error {
+ // TODO: check that the sent message is expected,
+ // otherwise return an error
+ return nil
+}
+
+func (m *getServiceFileDescriptorMock) Recv() (*reflectpb.ServerReflectionResponse, error) {
+ n := atomic.AddInt64(&m.nreqs, 1)
+ ptr := func(s string) (sptr *string) {
+ return &s
+ }
+ index := n - 1
+ fdp := &descriptorpb.FileDescriptorProto{
+ Package: ptr(m.pkgs[index]),
+ Name: ptr(m.names[index]),
+ }
+ b, err := proto.Marshal(fdp)
+ if err != nil {
+ return nil, err
+ }
+ srr := &reflectpb.ServerReflectionResponse{
+ MessageResponse: &reflectpb.ServerReflectionResponse_FileDescriptorResponse{
+ FileDescriptorResponse: &reflectpb.FileDescriptorResponse{
+ FileDescriptorProto: [][]byte{b},
+ },
+ },
+ }
+ return srr, nil
+}
+
+func methodFromProto(proto string, method string) protoreflect.MethodDescriptor {
+ parser := protoparse.Parser{
+ InferImportPaths: false,
+ Accessor: protoparse.FileAccessor(func(filename string) (io.ReadCloser, error) {
+ b := `
+syntax = "proto3";
+
+package hello;
+
+service HelloService {
+ rpc SayHello(HelloRequest) returns (HelloResponse);
+ rpc LotsOfReplies(HelloRequest) returns (stream HelloResponse);
+ rpc LotsOfGreetings(stream HelloRequest) returns (HelloResponse);
+ rpc BidiHello(stream HelloRequest) returns (stream HelloResponse);
+}
+
+message HelloRequest {
+ string greeting = 1;
+}
+
+message HelloResponse {
+ string reply = 1;
+}`
+ return io.NopCloser(bytes.NewBufferString(b)), nil
+ }),
+ }
+
+ fds, err := parser.ParseFiles(proto)
+ if err != nil {
+ panic(err)
+ }
+
+ fd, err := protodesc.NewFile(fds[0].AsFileDescriptorProto(), nil)
+ if err != nil {
+ panic(err)
+ }
+
+ services := fd.Services()
+ if services.Len() == 0 {
+ panic("no available services")
+ }
+ return services.Get(0).Methods().ByName(protoreflect.Name(method))
+}
+
+// invokemock is a mock for the grpc connection supporting only unary requests.
+type invokemock func(in, out *dynamicpb.Message) error
+
+func (im invokemock) Invoke(ctx context.Context, method string, payload interface{}, reply interface{}, opts ...grpc.CallOption) error {
+ in, ok := payload.(*dynamicpb.Message)
+ if !ok {
+ return fmt.Errorf("unexpected type for payload")
+ }
+ out, ok := reply.(*dynamicpb.Message)
+ if !ok {
+ return fmt.Errorf("unexpected type for reply")
+ }
+ return im(in, out)
+}
+
+func (invokemock) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+ panic("not implemented")
+}
+
+func (invokemock) Close() error {
+ return nil
+}
diff --git a/lib/netext/grpcext/reflect.go b/lib/netext/grpcext/reflect.go
new file mode 100644
index 000000000000..92f7b98eae6a
--- /dev/null
+++ b/lib/netext/grpcext/reflect.go
@@ -0,0 +1,117 @@
+package grpcext
+
+import (
+ "context"
+ "fmt"
+
+ "google.golang.org/grpc"
+ reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/descriptorpb"
+)
+
+// ReflectionClient wraps a grpc.ServerReflectionClient.
+type ReflectionClient struct {
+ Conn grpc.ClientConnInterface
+}
+
+// Reflect will use the grpc reflection api to make the file descriptors available to request.
+// It is called in the connect function the first time the Client.Connect function is called.
+func (rc *ReflectionClient) Reflect(ctx context.Context) (*descriptorpb.FileDescriptorSet, error) {
+ client := reflectpb.NewServerReflectionClient(rc.Conn)
+ methodClient, err := client.ServerReflectionInfo(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("can't get server info: %w", err)
+ }
+ req := &reflectpb.ServerReflectionRequest{
+ MessageRequest: &reflectpb.ServerReflectionRequest_ListServices{},
+ }
+ resp, err := sendReceive(methodClient, req)
+ if err != nil {
+ return nil, fmt.Errorf("can't list services: %w", err)
+ }
+ listResp := resp.GetListServicesResponse()
+ if listResp == nil {
+ return nil, fmt.Errorf("can't list services, nil response")
+ }
+ fdset, err := rc.resolveServiceFileDescriptors(methodClient, listResp)
+ if err != nil {
+ return nil, fmt.Errorf("can't resolve services' file descriptors: %w", err)
+ }
+ return fdset, nil
+}
+
+func (rc *ReflectionClient) resolveServiceFileDescriptors(
+ client sendReceiver,
+ res *reflectpb.ListServiceResponse,
+) (*descriptorpb.FileDescriptorSet, error) {
+ services := res.GetService()
+ seen := make(map[fileDescriptorLookupKey]bool, len(services))
+ fdset := &descriptorpb.FileDescriptorSet{
+ File: make([]*descriptorpb.FileDescriptorProto, 0, len(services)),
+ }
+
+ for _, service := range services {
+ req := &reflectpb.ServerReflectionRequest{
+ MessageRequest: &reflectpb.ServerReflectionRequest_FileContainingSymbol{
+ FileContainingSymbol: service.GetName(),
+ },
+ }
+ resp, err := sendReceive(client, req)
+ if err != nil {
+ return nil, fmt.Errorf("can't get method on service %q: %w", service, err)
+ }
+ fdResp := resp.GetFileDescriptorResponse()
+ for _, raw := range fdResp.GetFileDescriptorProto() {
+ var fdp descriptorpb.FileDescriptorProto
+ if err = proto.Unmarshal(raw, &fdp); err != nil {
+ return nil, fmt.Errorf("can't unmarshal proto on service %q: %w", service, err)
+ }
+ fdkey := fileDescriptorLookupKey{
+ Package: *fdp.Package,
+ Name: *fdp.Name,
+ }
+ if seen[fdkey] {
+ // When a proto file contains declarations for multiple services
+ // then the same proto file is returned multiple times,
+ // this prevents adding the returned proto file as a duplicate.
+ continue
+ }
+ seen[fdkey] = true
+ fdset.File = append(fdset.File, &fdp)
+ }
+ }
+ return fdset, nil
+}
+
+// sendReceiver is a smaller interface for decoupling
+// from `reflectpb.ServerReflection_ServerReflectionInfoClient`,
+// that has the dependency from `grpc.ClientStream`,
+// which is too much in the case the requirement is to just make a reflection's request.
+// It makes the API more restricted and with a controlled surface,
+// in this way the testing should be easier also.
+type sendReceiver interface {
+ Send(*reflectpb.ServerReflectionRequest) error
+ Recv() (*reflectpb.ServerReflectionResponse, error)
+}
+
+// sendReceive sends a request to a reflection client and,
+// receives a response.
+func sendReceive(
+ client sendReceiver,
+ req *reflectpb.ServerReflectionRequest,
+) (*reflectpb.ServerReflectionResponse, error) {
+ if err := client.Send(req); err != nil {
+ return nil, fmt.Errorf("can't send request: %w", err)
+ }
+ resp, err := client.Recv()
+ if err != nil {
+ return nil, fmt.Errorf("can't receive response: %w", err)
+ }
+ return resp, nil
+}
+
+type fileDescriptorLookupKey struct {
+ Package string
+ Name string
+}