From 79a5679dcfe0faf9e234c6c38a123b4980977191 Mon Sep 17 00:00:00 2001 From: Alexander Morozov Date: Wed, 26 Oct 2016 12:43:33 -0700 Subject: [PATCH] raft: introducing transport package This package is a separate gRPC transport layer for the raft package. Previously this lived in the membership package plus one very large method in the raft package. Signed-off-by: Alexander Morozov --- log/context.go | 15 + manager/controlapi/node_test.go | 13 +- manager/state/raft/membership/cluster.go | 219 ++------ manager/state/raft/membership/cluster_test.go | 4 +- manager/state/raft/raft.go | 526 ++++++++---------- manager/state/raft/raft_test.go | 37 +- manager/state/raft/storage.go | 52 +- manager/state/raft/storage_test.go | 6 +- manager/state/raft/testutils/testutils.go | 3 +- .../state/raft/transport/mock_raft_test.go | 168 ++++ manager/state/raft/transport/peer.go | 299 ++++ manager/state/raft/transport/transport.go | 381 +++++++++++++ .../state/raft/transport/transport_test.go | 286 ++++++++++ 13 files changed, 1467 insertions(+), 542 deletions(-) create mode 100644 manager/state/raft/transport/mock_raft_test.go create mode 100644 manager/state/raft/transport/peer.go create mode 100644 manager/state/raft/transport/transport.go create mode 100644 manager/state/raft/transport/transport_test.go diff --git a/log/context.go b/log/context.go index 3da380f112..ce7da930ba 100644 --- a/log/context.go +++ b/log/context.go @@ -29,6 +29,21 @@ func WithLogger(ctx context.Context, logger *logrus.Entry) context.Context { return context.WithValue(ctx, loggerKey{}, logger) } +// WithFields returns a new context with the given fields added to its logger. +func WithFields(ctx context.Context, fields logrus.Fields) context.Context { + logger := ctx.Value(loggerKey{}) + + if logger == nil { + logger = L + } + return WithLogger(ctx, logger.(*logrus.Entry).WithFields(fields)) +} + +// WithField is a convenience wrapper around WithFields. +func WithField(ctx context.Context, key, value string) context.Context { + return WithFields(ctx, logrus.Fields{key: value}) +} + // GetLogger retrieves the current logger from the context. If no logger is // available, the default logger is returned.
func GetLogger(ctx context.Context) *logrus.Entry { diff --git a/manager/controlapi/node_test.go b/manager/controlapi/node_test.go index 84cdaf9b03..71c11ca902 100644 --- a/manager/controlapi/node_test.go +++ b/manager/controlapi/node_test.go @@ -373,9 +373,6 @@ func TestListManagerNodes(t *testing.T) { return nil })) - // Switch the raft node used by the server - ts.Server.raft = nodes[2].Node - // Stop node 1 (leader) nodes[1].Server.Stop() nodes[1].ShutdownRaft() @@ -390,6 +387,16 @@ func TestListManagerNodes(t *testing.T) { // Wait for the re-election to occur raftutils.WaitForCluster(t, clockSource, newCluster) + var leaderNode *raftutils.TestNode + for _, node := range newCluster { + if node.IsLeader() { + leaderNode = node + } + } + + // Switch the raft node used by the server + ts.Server.raft = leaderNode.Node + // Node 1 should not be the leader anymore assert.NoError(t, raftutils.PollFunc(clockSource, func() error { r, err = ts.Client.ListNodes(context.Background(), &api.ListNodesRequest{}) diff --git a/manager/state/raft/membership/cluster.go b/manager/state/raft/membership/cluster.go index 84c9514066..0bf69da151 100644 --- a/manager/state/raft/membership/cluster.go +++ b/manager/state/raft/membership/cluster.go @@ -2,16 +2,12 @@ package membership import ( "errors" - "fmt" "sync" - "google.golang.org/grpc" - "github.com/coreos/etcd/raft/raftpb" "github.com/docker/swarmkit/api" "github.com/docker/swarmkit/watch" "github.com/gogo/protobuf/proto" - "golang.org/x/net/context" ) var ( @@ -25,26 +21,19 @@ var ( ErrConfigChangeInvalid = errors.New("membership: ConfChange type should be either AddNode, RemoveNode or UpdateNode") // ErrCannotUnmarshalConfig is thrown when a node cannot unmarshal a configuration change ErrCannotUnmarshalConfig = errors.New("membership: cannot unmarshal configuration change") + // ErrMemberRemoved is thrown when a node was removed from the cluster + ErrMemberRemoved = errors.New("raft: member was removed from the cluster") ) -// deferredConn used to store removed members connection for some time. -// We need this in case if removed node is redirector or endpoint of ControlAPI call. -type deferredConn struct { - tick int - conn *grpc.ClientConn -} - // Cluster represents a set of active // raft Members type Cluster struct { - mu sync.RWMutex - members map[uint64]*Member - deferedConns map[*deferredConn]struct{} + mu sync.RWMutex + members map[uint64]*Member // removed contains the list of removed Members, // those ids cannot be reused - removed map[uint64]bool - heartbeatTicks int + removed map[uint64]bool PeersBroadcast *watch.Queue } @@ -52,74 +41,19 @@ type Cluster struct { // Member represents a raft Cluster Member type Member struct { *api.RaftMember - - Conn *grpc.ClientConn - tick int - active bool - lastSeenHost string -} - -// HealthCheck sends a health check RPC to the member and returns the response. -func (member *Member) HealthCheck(ctx context.Context) error { - healthClient := api.NewHealthClient(member.Conn) - resp, err := healthClient.Check(ctx, &api.HealthCheckRequest{Service: "Raft"}) - if err != nil { - return err - } - if resp.Status != api.HealthCheckResponse_SERVING { - return fmt.Errorf("health check returned status %s", resp.Status.String()) - } - return nil } // NewCluster creates a new Cluster neighbors list for a raft Member. -// Member marked as inactive if there was no call ReportActive for heartbeatInterval. 
-func NewCluster(heartbeatTicks int) *Cluster { +func NewCluster() *Cluster { // TODO(abronan): generate Cluster ID for federation return &Cluster{ members: make(map[uint64]*Member), removed: make(map[uint64]bool), - deferedConns: make(map[*deferredConn]struct{}), - heartbeatTicks: heartbeatTicks, PeersBroadcast: watch.NewQueue(), } } -func (c *Cluster) handleInactive() { - for _, m := range c.members { - if !m.active { - continue - } - m.tick++ - if m.tick > c.heartbeatTicks { - m.active = false - if m.Conn != nil { - m.Conn.Close() - } - } - } -} - -func (c *Cluster) handleDeferredConns() { - for dc := range c.deferedConns { - dc.tick++ - if dc.tick > c.heartbeatTicks { - dc.conn.Close() - delete(c.deferedConns, dc) - } - } -} - -// Tick increases ticks for all members. After heartbeatTicks node marked as -// inactive. -func (c *Cluster) Tick() { - c.mu.Lock() - defer c.mu.Unlock() - c.handleInactive() - c.handleDeferredConns() -} - // Members returns the list of raft Members in the Cluster. func (c *Cluster) Members() map[uint64]*Member { members := make(map[uint64]*Member) @@ -168,8 +102,6 @@ func (c *Cluster) AddMember(member *Member) error { if c.removed[member.RaftID] { return ErrIDRemoved } - member.active = true - member.tick = 0 c.members[member.RaftID] = member @@ -187,55 +119,47 @@ func (c *Cluster) RemoveMember(id uint64) error { return c.clearMember(id) } -// ClearMember removes a node from the Cluster Memberlist, but does NOT add it -// to the removed list. -func (c *Cluster) ClearMember(id uint64) error { +// UpdateMember updates member address. +func (c *Cluster) UpdateMember(id uint64, m *api.RaftMember) error { c.mu.Lock() defer c.mu.Unlock() - return c.clearMember(id) -} - -func (c *Cluster) clearMember(id uint64) error { - m, ok := c.members[id] - if ok { - if m.Conn != nil { - // defer connection close to after heartbeatTicks - dConn := &deferredConn{conn: m.Conn} - c.deferedConns[dConn] = struct{}{} - } - delete(c.members, id) + if c.removed[id] { + return ErrIDRemoved } - c.broadcastUpdate() - return nil -} - -// ReplaceMemberConnection replaces the member's GRPC connection. -func (c *Cluster) ReplaceMemberConnection(id uint64, oldConn *Member, newConn *Member, newAddr string, force bool) error { - c.mu.Lock() - defer c.mu.Unlock() oldMember, ok := c.members[id] if !ok { return ErrIDNotFound } - if !force && oldConn.Conn != oldMember.Conn { - // The connection was already replaced. Don't do it again. - newConn.Conn.Close() - return nil + if oldMember.NodeID != m.NodeID { + // Should never happen; this is a sanity check + return errors.New("node ID mismatch match on node update") } - if oldMember.Conn != nil { - oldMember.Conn.Close() + if oldMember.Addr == m.Addr { + // nothing to do + return nil } + oldMember.RaftMember = m + return nil +} + +// ClearMember removes a node from the Cluster Memberlist, but does NOT add it +// to the removed list. +func (c *Cluster) ClearMember(id uint64) error { + c.mu.Lock() + defer c.mu.Unlock() - newMember := *oldMember - newMember.RaftMember = oldMember.RaftMember.Copy() - newMember.RaftMember.Addr = newAddr - newMember.Conn = newConn.Conn - c.members[id] = &newMember + return c.clearMember(id) +} +func (c *Cluster) clearMember(id uint64) error { + if _, ok := c.members[id]; ok { + delete(c.members, id) + c.broadcastUpdate() + } return nil } @@ -249,60 +173,12 @@ func (c *Cluster) IsIDRemoved(id uint64) bool { // Clear resets the list of active Members and removed Members. 
func (c *Cluster) Clear() { c.mu.Lock() - for _, member := range c.members { - if member.Conn != nil { - member.Conn.Close() - } - } - - for dc := range c.deferedConns { - dc.conn.Close() - } c.members = make(map[uint64]*Member) c.removed = make(map[uint64]bool) - c.deferedConns = make(map[*deferredConn]struct{}) c.mu.Unlock() } -// ReportActive reports that member is active (called ProcessRaftMessage), -func (c *Cluster) ReportActive(id uint64, sourceHost string) { - c.mu.Lock() - defer c.mu.Unlock() - m, ok := c.members[id] - if !ok { - return - } - m.tick = 0 - m.active = true - if sourceHost != "" { - m.lastSeenHost = sourceHost - } -} - -// Active returns true if node is active. -func (c *Cluster) Active(id uint64) bool { - c.mu.RLock() - defer c.mu.RUnlock() - m, ok := c.members[id] - if !ok { - return false - } - return m.active -} - -// LastSeenHost returns the last observed source address that the specified -// member connected from. -func (c *Cluster) LastSeenHost(id uint64) string { - c.mu.RLock() - defer c.mu.RUnlock() - m, ok := c.members[id] - if ok { - return m.lastSeenHost - } - return "" -} - // ValidateConfigurationChange takes a proposed ConfChange and // ensures that it is valid. func (c *Cluster) ValidateConfigurationChange(cc raftpb.ConfChange) error { @@ -334,34 +210,3 @@ func (c *Cluster) ValidateConfigurationChange(cc raftpb.ConfChange) error { } return nil } - -// CanRemoveMember checks if removing a Member would not result in a loss -// of quorum, this check is needed before submitting a configuration change -// that might block or harm the Cluster on Member recovery -func (c *Cluster) CanRemoveMember(from uint64, id uint64) bool { - members := c.Members() - nreachable := 0 // reachable managers after removal - - for _, m := range members { - if m.RaftID == id { - continue - } - - // Local node from where the remove is issued - if m.RaftID == from { - nreachable++ - continue - } - - if c.Active(m.RaftID) { - nreachable++ - } - } - - nquorum := (len(members)-1)/2 + 1 - if nreachable < nquorum { - return false - } - - return true -} diff --git a/manager/state/raft/membership/cluster_test.go b/manager/state/raft/membership/cluster_test.go index ef519a66df..6580d7f637 100644 --- a/manager/state/raft/membership/cluster_test.go +++ b/manager/state/raft/membership/cluster_test.go @@ -42,7 +42,7 @@ func newTestMember(id uint64) *membership.Member { } func newTestCluster(members []*membership.Member, removed []*membership.Member) *membership.Cluster { - c := membership.NewCluster(3) + c := membership.NewCluster() for _, m := range members { c.AddMember(m) } @@ -79,7 +79,7 @@ func TestClusterMember(t *testing.T) { } func TestMembers(t *testing.T) { - cls := membership.NewCluster(1) + cls := membership.NewCluster() defer cls.Clear() cls.AddMember(&membership.Member{RaftMember: &api.RaftMember{RaftID: 1}}) cls.AddMember(&membership.Member{RaftMember: &api.RaftMember{RaftID: 5}}) diff --git a/manager/state/raft/raft.go b/manager/state/raft/raft.go index a19f07b6e0..c1b4d16709 100644 --- a/manager/state/raft/raft.go +++ b/manager/state/raft/raft.go @@ -27,6 +27,7 @@ import ( "github.com/docker/swarmkit/manager/raftselector" "github.com/docker/swarmkit/manager/state/raft/membership" "github.com/docker/swarmkit/manager/state/raft/storage" + "github.com/docker/swarmkit/manager/state/raft/transport" "github.com/docker/swarmkit/manager/state/store" "github.com/docker/swarmkit/watch" "github.com/gogo/protobuf/proto" @@ -51,8 +52,6 @@ var ( ErrRequestTooLarge = errors.New("raft: raft 
message is too large and can't be sent") // ErrCannotRemoveMember is thrown when we try to remove a member from the cluster but this would result in a loss of quorum ErrCannotRemoveMember = errors.New("raft: member cannot be removed, because removing it may result in loss of quorum") - // ErrMemberRemoved is thrown when a node was removed from the cluster - ErrMemberRemoved = errors.New("raft: member was removed from the cluster") // ErrNoClusterLeader is thrown when the cluster has no elected leader ErrNoClusterLeader = errors.New("raft: no elected cluster leader") // ErrMemberUnknown is sent in response to a message from an @@ -88,8 +87,9 @@ type EncryptionKeyRotator interface { // Node represents the Raft Node useful // configuration. type Node struct { - raftNode raft.Node - cluster *membership.Cluster + raftNode raft.Node + cluster *membership.Cluster + transport *transport.Transport raftStore *raft.MemoryStorage memoryStore *store.MemoryStore @@ -100,6 +100,7 @@ type Node struct { campaignWhenAble bool signalledLeadership uint32 isMember uint32 + bootstrapMembers []*api.RaftMember // waitProp waits for all the proposals to be terminated before // shutting down the node. @@ -113,9 +114,11 @@ type Node struct { ticker clock.Ticker doneCh chan struct{} // RemovedFromRaft notifies about node deletion from raft cluster - RemovedFromRaft chan struct{} - removeRaftFunc func() - cancelFunc func() + RemovedFromRaft chan struct{} + cancelFunc func() + // removeRaftCh notifies about node deletion from raft cluster + removeRaftCh chan struct{} + removeRaftOnce sync.Once leadershipBroadcast *watch.Queue // used to coordinate shutdown @@ -131,7 +134,6 @@ type Node struct { // to stop. stopped chan struct{} - lastSendToMember map[uint64]chan struct{} raftLogger *storage.EncryptedRaftLogger keyRotator EncryptionKeyRotator rotationQueued bool @@ -189,7 +191,7 @@ func NewNode(opts NodeOptions) *Node { raftStore := raft.NewMemoryStorage() n := &Node{ - cluster: membership.NewCluster(2 * cfg.ElectionTick), + cluster: membership.NewCluster(), raftStore: raftStore, opts: opts, Config: &raft.Config{ @@ -204,7 +206,6 @@ func NewNode(opts NodeOptions) *Node { RemovedFromRaft: make(chan struct{}), stopped: make(chan struct{}), leadershipBroadcast: watch.NewQueue(), - lastSendToMember: make(map[uint64]chan struct{}), keyRotator: opts.KeyRotator, } n.memoryStore = store.NewMemoryStore(n) @@ -218,16 +219,6 @@ func NewNode(opts NodeOptions) *Node { n.reqIDGen = idutil.NewGenerator(uint16(n.Config.ID), time.Now()) n.wait = newWait() - n.removeRaftFunc = func(n *Node) func() { - var removeRaftOnce sync.Once - return func() { - removeRaftOnce.Do(func() { - atomic.StoreUint32(&n.isMember, 0) - close(n.RemovedFromRaft) - }) - } - }(n) - n.cancelFunc = func(n *Node) func() { var cancelOnce sync.Once return func() { @@ -240,6 +231,34 @@ func NewNode(opts NodeOptions) *Node { return n } +// IsIDRemoved reports if member with id was removed from cluster. +// Part of transport.Raft interface. +func (n *Node) IsIDRemoved(id uint64) bool { + return n.cluster.IsIDRemoved(id) +} + +// NodeRemoved signals that node was removed from cluster and should stop. +// Part of transport.Raft interface. +func (n *Node) NodeRemoved() { + n.removeRaftOnce.Do(func() { + atomic.StoreUint32(&n.isMember, 0) + close(n.RemovedFromRaft) + }) +} + +// ReportSnapshot reports snapshot status to underlying raft node. +// Part of transport.Raft interface. 
+func (n *Node) ReportSnapshot(id uint64, status raft.SnapshotStatus) { + n.raftNode.ReportSnapshot(id, status) +} + +// ReportUnreachable reports to underlying raft node that member with id is +// unreachable. +// Part of transport.Raft interface. +func (n *Node) ReportUnreachable(id uint64) { + n.raftNode.ReportUnreachable(id) +} + // WithContext returns context which is cancelled when parent context cancelled // or node is stopped. func (n *Node) WithContext(ctx context.Context) (context.Context, context.CancelFunc) { @@ -255,12 +274,26 @@ func (n *Node) WithContext(ctx context.Context) (context.Context, context.Cancel return ctx, cancel } +func (n *Node) initTransport() { + transportConfig := &transport.Config{ + HeartbeatInterval: time.Duration(n.Config.ElectionTick) * n.opts.TickInterval, + SendTimeout: n.opts.SendTimeout, + Credentials: n.opts.TLSCredentials, + Raft: n, + } + n.transport = transport.New(transportConfig) +} + // JoinAndStart joins and starts the raft server func (n *Node) JoinAndStart(ctx context.Context) (err error) { ctx, cancel := n.WithContext(ctx) defer func() { cancel() if err != nil { + n.stopMu.Lock() + // to shutdown transport + close(n.stopped) + n.stopMu.Unlock() n.done() } }() @@ -283,14 +316,12 @@ func (n *Node) JoinAndStart(ctx context.Context) (err error) { if loadAndStartErr == storage.ErrNoWAL { if n.opts.JoinAddr != "" { - c, err := n.ConnectToMember(n.opts.JoinAddr, 10*time.Second) + conn, err := dial(n.opts.JoinAddr, "tcp", n.opts.TLSCredentials, 10*time.Second) if err != nil { return err } - client := api.NewRaftMembershipClient(c.Conn) - defer func() { - _ = c.Conn.Close() - }() + defer conn.Close() + client := api.NewRaftMembershipClient(conn) joinCtx, joinCancel := context.WithTimeout(ctx, 10*time.Second) defer joinCancel() @@ -307,12 +338,10 @@ func (n *Node) JoinAndStart(ctx context.Context) (err error) { return err } + n.initTransport() n.raftNode = raft.StartNode(n.Config, []raft.Peer{}) - if err := n.registerNodes(resp.Members); err != nil { - n.raftLogger.Close(ctx) - return err - } + n.bootstrapMembers = resp.Members } else { // First member in the cluster, self-assign ID n.Config.ID = uint64(rand.Int63()) + 1 @@ -320,6 +349,7 @@ func (n *Node) JoinAndStart(ctx context.Context) (err error) { if err != nil { return err } + n.initTransport() n.raftNode = raft.StartNode(n.Config, []raft.Peer{peer}) n.campaignWhenAble = true } @@ -331,6 +361,7 @@ func (n *Node) JoinAndStart(ctx context.Context) (err error) { log.G(ctx).Warning("ignoring request to join cluster, because raft state already exists") } n.campaignWhenAble = true + n.initTransport() n.raftNode = raft.RestartNode(n.Config) atomic.StoreUint32(&n.isMember, 1) return nil @@ -372,6 +403,9 @@ func (n *Node) done() { n.leadershipBroadcast.Close() n.cluster.PeersBroadcast.Close() n.memoryStore.Close() + if n.transport != nil { + n.transport.Stop() + } close(n.doneCh) } @@ -391,6 +425,12 @@ func (n *Node) Run(ctx context.Context) error { ctx = log.WithLogger(ctx, logrus.WithField("raft_id", fmt.Sprintf("%x", n.Config.ID))) ctx, cancel := context.WithCancel(ctx) + for _, node := range n.bootstrapMembers { + if err := n.registerNode(node); err != nil { + log.G(ctx).WithError(err).Errorf("failed to register member %x", node.RaftID) + } + } + defer func() { cancel() n.stop(ctx) @@ -414,7 +454,6 @@ func (n *Node) Run(ctx context.Context) error { select { case <-n.ticker.C(): n.raftNode.Tick() - n.cluster.Tick() case rd := <-n.raftNode.Ready(): raftConfig := n.getCurrentRaftConfig() @@ -423,10 
+462,10 @@ func (n *Node) Run(ctx context.Context) error { return errors.Wrap(err, "failed to save entries to storage") } - if len(rd.Messages) != 0 { + for _, msg := range rd.Messages { // Send raft messages to peers - if err := n.send(ctx, rd.Messages); err != nil { - log.G(ctx).WithError(err).Error("failed to send message to members") + if err := n.transport.Send(msg); err != nil { + log.G(ctx).WithError(err).Error("failed to send message to member") } } @@ -435,8 +474,8 @@ func (n *Node) Run(ctx context.Context) error { // saveToStorage. if !raft.IsEmptySnap(rd.Snapshot) { // Load the snapshot data into the store - if err := n.restoreFromSnapshot(rd.Snapshot.Data, false); err != nil { - log.G(ctx).WithError(err).Error("failed to restore from snapshot") + if err := n.restoreFromSnapshot(ctx, rd.Snapshot.Data); err != nil { + log.G(ctx).WithError(err).Error("failed to restore cluster from snapshot") } n.appliedIndex = rd.Snapshot.Metadata.Index n.snapshotMeta = rd.Snapshot.Metadata @@ -555,6 +594,40 @@ func (n *Node) Run(ctx context.Context) error { } } +func (n *Node) restoreFromSnapshot(ctx context.Context, data []byte) error { + snapCluster, err := n.clusterSnapshot(data) + if err != nil { + return err + } + + oldMembers := n.cluster.Members() + + for _, member := range snapCluster.Members { + delete(oldMembers, member.RaftID) + } + + for _, removedMember := range snapCluster.Removed { + n.cluster.RemoveMember(removedMember) + if err := n.transport.RemovePeer(removedMember); err != nil { + log.G(ctx).WithError(err).Errorf("failed to remove peer %x from transport", removedMember) + } + delete(oldMembers, removedMember) + } + + for id, member := range oldMembers { + n.cluster.ClearMember(id) + if err := n.transport.RemovePeer(member.RaftID); err != nil { + log.G(ctx).WithError(err).Errorf("failed to remove peer %x from transport", member.RaftID) + } + } + for _, node := range snapCluster.Members { + if err := n.registerNode(&api.RaftMember{RaftID: node.RaftID, NodeID: node.NodeID, Addr: node.Addr}); err != nil { + log.G(ctx).WithError(err).Error("failed to register node from snapshot") + } + } + return nil +} + func (n *Node) needsSnapshot(ctx context.Context) bool { if n.waitForAppliedIndex == 0 && n.keyRotator.NeedsRotation() { keys := n.keyRotator.GetKeys() @@ -798,22 +871,27 @@ func (n *Node) Join(ctx context.Context, req *api.JoinRequest) (*api.JoinRespons // checkHealth tries to contact an aspiring member through its advertised address // and checks if its raft server is running. 
func (n *Node) checkHealth(ctx context.Context, addr string, timeout time.Duration) error { - conn, err := n.ConnectToMember(addr, timeout) + conn, err := dial(addr, "tcp", n.opts.TLSCredentials, timeout) if err != nil { return err } + defer conn.Close() + if timeout != 0 { tctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() ctx = tctx } - defer conn.Conn.Close() - - if err := conn.HealthCheck(ctx); err != nil { + healthClient := api.NewHealthClient(conn) + resp, err := healthClient.Check(ctx, &api.HealthCheckRequest{Service: "Raft"}) + if err != nil { return errors.Wrap(err, "could not connect to prospective new cluster member using its advertised address") } + if resp.Status != api.HealthCheckResponse_SERVING { + return fmt.Errorf("health check returned status %s", resp.Status.String()) + } return nil } @@ -841,11 +919,15 @@ func (n *Node) addMember(ctx context.Context, addr string, raftID uint64, nodeID return n.configure(ctx, cc) } -// updateMember submits a configuration change to change a member's address. -func (n *Node) updateMember(ctx context.Context, addr string, raftID uint64, nodeID string) error { +// updateNodeBlocking runs synchronous job to update node address in whole cluster. +func (n *Node) updateNodeBlocking(ctx context.Context, id uint64, addr string) error { + m := n.cluster.GetMember(id) + if m == nil { + return errors.Errorf("member %x is not found for update", id) + } node := api.RaftMember{ - RaftID: raftID, - NodeID: nodeID, + RaftID: m.RaftID, + NodeID: m.NodeID, Addr: addr, } @@ -856,7 +938,7 @@ func (n *Node) updateMember(ctx context.Context, addr string, raftID uint64, nod cc := raftpb.ConfChange{ Type: raftpb.ConfChangeUpdateNode, - NodeID: raftID, + NodeID: id, Context: meta, } @@ -864,6 +946,18 @@ func (n *Node) updateMember(ctx context.Context, addr string, raftID uint64, nod return n.configure(ctx, cc) } +// UpdateNode submits a configuration change to change a member's address. +func (n *Node) UpdateNode(id uint64, addr string) { + ctx, cancel := n.WithContext(context.Background()) + defer cancel() + // spawn updating info in raft in background to unblock transport + go func() { + if err := n.updateNodeBlocking(ctx, id, addr); err != nil { + log.G(ctx).WithFields(logrus.Fields{"raft_id": n.Config.ID, "update_id": id}).WithError(err).Error("failed to update member address in cluster") + } + }() +} + // Leave asks to a member of the raft to remove // us from the raft cluster. This method is called // from a member who is willing to leave its raft @@ -897,7 +991,31 @@ func (n *Node) Leave(ctx context.Context, req *api.LeaveRequest) (*api.LeaveResp // CanRemoveMember checks if a member can be removed from // the context of the current node. 
func (n *Node) CanRemoveMember(id uint64) bool { - return n.cluster.CanRemoveMember(n.Config.ID, id) + members := n.cluster.Members() + nreachable := 0 // reachable managers after removal + + for _, m := range members { + if m.RaftID == id { + continue + } + + // Local node from where the remove is issued + if m.RaftID == n.Config.ID { + nreachable++ + continue + } + + if n.transport.Active(m.RaftID) { + nreachable++ + } + } + + nquorum := (len(members)-1)/2 + 1 + if nreachable < nquorum { + return false + } + + return true } func (n *Node) removeMember(ctx context.Context, id uint64) error { @@ -915,7 +1033,7 @@ func (n *Node) removeMember(ctx context.Context, id uint64) error { n.membershipLock.Lock() defer n.membershipLock.Unlock() - if n.cluster.CanRemoveMember(n.Config.ID, id) { + if n.CanRemoveMember(id) { cc := raftpb.ConfChange{ ID: id, Type: raftpb.ConfChangeRemoveNode, @@ -956,6 +1074,34 @@ func (n *Node) processRaftMessageLogger(ctx context.Context, msg *api.ProcessRaf return log.G(ctx).WithFields(fields) } +func (n *Node) reportNewAddress(ctx context.Context, id uint64) error { + // too early + if !n.IsMember() { + return nil + } + p, ok := peer.FromContext(ctx) + if !ok { + return nil + } + oldAddr, err := n.transport.PeerAddr(id) + if err != nil { + return err + } + newHost, _, err := net.SplitHostPort(p.Addr.String()) + if err != nil { + return err + } + _, officialPort, err := net.SplitHostPort(oldAddr) + if err != nil { + return err + } + newAddr := net.JoinHostPort(newHost, officialPort) + if err := n.transport.UpdatePeerAddr(id, newAddr); err != nil { + return err + } + return nil +} + // ProcessRaftMessage calls 'Step' which advances the // raft state machine with the provided message on the // receiving node @@ -969,32 +1115,25 @@ func (n *Node) ProcessRaftMessage(ctx context.Context, msg *api.ProcessRaftMessa // a node in the remove set if n.cluster.IsIDRemoved(msg.Message.From) { n.processRaftMessageLogger(ctx, msg).Debug("received message from removed member") - return nil, grpc.Errorf(codes.NotFound, "%s", ErrMemberRemoved.Error()) - } - - var sourceHost string - peer, ok := peer.FromContext(ctx) - if ok { - sourceHost, _, _ = net.SplitHostPort(peer.Addr.String()) + return nil, grpc.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error()) } - n.cluster.ReportActive(msg.Message.From, sourceHost) - ctx, cancel := n.WithContext(ctx) defer cancel() + if err := n.reportNewAddress(ctx, msg.Message.From); err != nil { + log.G(ctx).WithError(err).Errorf("failed to report new address of %x to transport", msg.Message.From) + } + // Reject vote requests from unreachable peers if msg.Message.Type == raftpb.MsgVote { member := n.cluster.GetMember(msg.Message.From) - if member == nil || member.Conn == nil { + if member == nil { n.processRaftMessageLogger(ctx, msg).Debug("received message from unknown member") return &api.ProcessRaftMessageResponse{}, nil } - healthCtx, cancel := context.WithTimeout(ctx, time.Duration(n.Config.ElectionTick)*n.opts.TickInterval) - defer cancel() - - if err := member.HealthCheck(healthCtx); err != nil { + if err := n.transport.HealthCheck(ctx, msg.Message.From); err != nil { n.processRaftMessageLogger(ctx, msg).WithError(err).Debug("member which sent vote request failed health check") return &api.ProcessRaftMessageResponse{}, nil } @@ -1064,17 +1203,11 @@ func (n *Node) getLeaderConn() (*grpc.ClientConn, error) { if leader == n.Config.ID { return nil, raftselector.ErrIsLeader } - l := n.cluster.GetMember(leader) - if l == nil { - return 
nil, errors.New("no leader found") - } - if !n.cluster.Active(leader) { - return nil, errors.New("leader marked as inactive") - } - if l.Conn == nil { - return nil, errors.New("no connection to leader in member list") + conn, err := n.transport.PeerConn(leader) + if err != nil { + return nil, errors.Wrap(err, "failed to get connection to leader") } - return l.Conn, nil + return conn, nil } // LeaderConn returns current connection to cluster leader or raftselector.ErrIsLeader @@ -1122,8 +1255,12 @@ func (n *Node) registerNode(node *api.RaftMember) error { // and are adding ourself now with the remotely-reachable // address. if existingMember.Addr != node.Addr { + if node.RaftID != n.Config.ID { + if err := n.transport.UpdatePeer(node.RaftID, node.Addr); err != nil { + return err + } + } member.RaftMember = node - member.Conn = existingMember.Conn n.cluster.AddMember(member) } @@ -1132,11 +1269,7 @@ func (n *Node) registerNode(node *api.RaftMember) error { // Avoid opening a connection to the local node if node.RaftID != n.Config.ID { - // We don't want to impose a timeout on the grpc connection. It - // should keep retrying as long as necessary, in case the peer - // is temporarily unavailable. - var err error - if member, err = n.ConnectToMember(node.Addr, 0); err != nil { + if err := n.transport.AddPeer(node.RaftID, node.Addr); err != nil { return err } } @@ -1144,8 +1277,8 @@ func (n *Node) registerNode(node *api.RaftMember) error { member.RaftMember = node err := n.cluster.AddMember(member) if err != nil { - if member.Conn != nil { - _ = member.Conn.Close() + if rerr := n.transport.RemovePeer(node.RaftID); rerr != nil { + return errors.Wrapf(rerr, "failed to remove peer after error %v", err) } return err } @@ -1153,17 +1286,6 @@ func (n *Node) registerNode(node *api.RaftMember) error { return nil } -// registerNodes registers a set of nodes in the cluster -func (n *Node) registerNodes(nodes []*api.RaftMember) error { - for _, node := range nodes { - if err := n.registerNode(node); err != nil { - return err - } - } - - return nil -} - // ProposeValue calls Propose on the raft and waits // on the commit log action before returning a result func (n *Node) ProposeValue(ctx context.Context, storeAction []*api.StoreAction, cb func()) error { @@ -1209,7 +1331,7 @@ func (n *Node) GetMemberlist() map[uint64]*api.RaftMember { leader := false if member.RaftID != n.Config.ID { - if !n.cluster.Active(member.RaftID) { + if !n.transport.Active(member.RaftID) { reachability = api.RaftMemberStatus_UNREACHABLE } } @@ -1294,183 +1416,6 @@ func (n *Node) saveToStorage( return nil } -// Sends a series of messages to members in the raft -func (n *Node) send(ctx context.Context, messages []raftpb.Message) error { - members := n.cluster.Members() - - n.stopMu.RLock() - defer n.stopMu.RUnlock() - - for _, m := range messages { - // Process locally - if m.To == n.Config.ID { - if err := n.raftNode.Step(ctx, m); err != nil { - return err - } - continue - } - - if m.Type == raftpb.MsgProp { - // We don't forward proposals to the leader. Our - // current architecture depends on only the leader - // making proposals, so in-flight proposals can be - // guaranteed not to conflict. 
- continue - } - - ch := make(chan struct{}) - - n.asyncTasks.Add(1) - go n.sendToMember(ctx, members, m, n.lastSendToMember[m.To], ch) - - n.lastSendToMember[m.To] = ch - } - - return nil -} - -func (n *Node) sendToMember(ctx context.Context, members map[uint64]*membership.Member, m raftpb.Message, lastSend <-chan struct{}, thisSend chan<- struct{}) { - defer n.asyncTasks.Done() - defer close(thisSend) - - if lastSend != nil { - waitCtx, waitCancel := context.WithTimeout(ctx, n.opts.SendTimeout) - defer waitCancel() - - select { - case <-lastSend: - case <-waitCtx.Done(): - return - } - - select { - case <-waitCtx.Done(): - return - default: - } - } - - ctx, cancel := context.WithTimeout(ctx, n.opts.SendTimeout) - defer cancel() - - if n.cluster.IsIDRemoved(m.To) { - // Should not send to removed members - return - } - - var conn *membership.Member - if toMember, ok := members[m.To]; ok { - conn = toMember - } else { - // If we are being asked to send to a member that's not in - // our member list, that could indicate that the current leader - // was added while we were offline. Try to resolve its address. - log.G(ctx).Warningf("sending message to an unrecognized member ID %x", m.To) - - // Choose a random member - var ( - queryMember *membership.Member - id uint64 - ) - for id, queryMember = range members { - if id != n.Config.ID { - break - } - } - - if queryMember == nil || queryMember.RaftID == n.Config.ID { - log.G(ctx).Error("could not find cluster member to query for leader address") - return - } - - resp, err := api.NewRaftClient(queryMember.Conn).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: m.To}) - if err != nil { - log.G(ctx).WithError(err).Errorf("could not resolve address of member ID %x", m.To) - return - } - conn, err = n.ConnectToMember(resp.Addr, n.opts.SendTimeout) - if err != nil { - log.G(ctx).WithError(err).Errorf("could connect to member ID %x at %s", m.To, resp.Addr) - return - } - // The temporary connection is only used for this message. - // Eventually, we should catch up and add a long-lived - // connection to the member list. 
- defer conn.Conn.Close() - } - - _, err := api.NewRaftClient(conn.Conn).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m}) - if err != nil { - if grpc.Code(err) == codes.NotFound && grpc.ErrorDesc(err) == ErrMemberRemoved.Error() { - n.removeRaftFunc() - } - if m.Type == raftpb.MsgSnap { - n.raftNode.ReportSnapshot(m.To, raft.SnapshotFailure) - } - if !n.IsMember() { - // node is removed from cluster or stopped - return - } - n.raftNode.ReportUnreachable(m.To) - - lastSeenHost := n.cluster.LastSeenHost(m.To) - if lastSeenHost != "" { - // Check if address has changed - officialHost, officialPort, _ := net.SplitHostPort(conn.Addr) - if officialHost != lastSeenHost { - reconnectAddr := net.JoinHostPort(lastSeenHost, officialPort) - log.G(ctx).Warningf("detected address change for %x (%s -> %s)", m.To, conn.Addr, reconnectAddr) - if err := n.handleAddressChange(ctx, conn, reconnectAddr); err != nil { - log.G(ctx).WithError(err).Error("failed to hande address change") - } - return - } - } - - // Bounce the connection - newConn, err := n.ConnectToMember(conn.Addr, 0) - if err != nil { - log.G(ctx).WithError(err).Errorf("could connect to member ID %x at %s", m.To, conn.Addr) - return - } - err = n.cluster.ReplaceMemberConnection(m.To, conn, newConn, conn.Addr, false) - if err != nil { - log.G(ctx).WithError(err).Error("failed to replace connection to raft member") - newConn.Conn.Close() - } - } else if m.Type == raftpb.MsgSnap { - n.raftNode.ReportSnapshot(m.To, raft.SnapshotFinish) - } -} - -func (n *Node) handleAddressChange(ctx context.Context, member *membership.Member, reconnectAddr string) error { - newConn, err := n.ConnectToMember(reconnectAddr, 0) - if err != nil { - return errors.Wrapf(err, "could connect to member ID %x at observed address %s", member.RaftID, reconnectAddr) - } - - healthCtx, cancelHealth := context.WithTimeout(ctx, time.Duration(n.Config.ElectionTick)*n.opts.TickInterval) - defer cancelHealth() - - if err := newConn.HealthCheck(healthCtx); err != nil { - return errors.Wrapf(err, "%x failed health check at observed address %s", member.RaftID, reconnectAddr) - } - - if err := n.cluster.ReplaceMemberConnection(member.RaftID, member, newConn, reconnectAddr, false); err != nil { - newConn.Conn.Close() - return errors.Wrap(err, "failed to replace connection to raft member") - } - - // If we're the leader, write the address change to raft - updateCtx, cancelUpdate := context.WithTimeout(ctx, time.Duration(n.Config.ElectionTick)*n.opts.TickInterval) - defer cancelUpdate() - if err := n.updateMember(updateCtx, reconnectAddr, member.RaftID, member.NodeID); err != nil { - return errors.Wrap(err, "failed to update member address in raft") - } - - return nil -} - // processInternalRaftRequest sends a message to nodes participating // in the raft to apply a log entry and then waits for it to be applied // on the server. 
It will block until the update is performed, there is @@ -1681,32 +1626,10 @@ func (n *Node) applyUpdateNode(ctx context.Context, cc raftpb.ConfChange) error return err } - oldMember := n.cluster.GetMember(newMember.RaftID) - - if oldMember == nil { - return ErrMemberUnknown - } - if oldMember.NodeID != newMember.NodeID { - // Should never happen; this is a sanity check - log.G(ctx).Errorf("node ID mismatch on node update (old: %x, new: %x)", oldMember.NodeID, newMember.NodeID) - return errors.New("node ID mismatch match on node update") - } - - if oldMember.Addr == newMember.Addr || oldMember.Conn == nil { - // nothing to do + if newMember.RaftID == n.Config.ID { return nil } - - newConn, err := n.ConnectToMember(newMember.Addr, 0) - if err != nil { - return errors.Errorf("could connect to member ID %x at %s: %v", newMember.RaftID, newMember.Addr, err) - } - if err := n.cluster.ReplaceMemberConnection(newMember.RaftID, oldMember, newConn, newMember.Addr, true); err != nil { - newConn.Conn.Close() - return err - } - - return nil + return n.cluster.UpdateMember(newMember.RaftID, newMember) } // applyRemoveNode is called when we receive a ConfChange @@ -1724,11 +1647,11 @@ func (n *Node) applyRemoveNode(ctx context.Context, cc raftpb.ConfChange) (err e } if cc.NodeID == n.Config.ID { + // wait the commit ack to be sent before closing connection n.asyncTasks.Wait() - n.removeRaftFunc() - + n.NodeRemoved() // if there are only 2 nodes in the cluster, and leader is leaving // before closing the connection, leader has to ensure that follower gets // noticed about this raft conf change commit. Otherwise, follower would @@ -1738,24 +1661,15 @@ func (n *Node) applyRemoveNode(ctx context.Context, cc raftpb.ConfChange) (err e // while n.asyncTasks.Wait() could be helpful in this case // it's the best-effort strategy, because this send could be fail due to some errors (such as time limit exceeds) // TODO(Runshen Zhu): use leadership transfer to solve this case, after vendoring raft 3.0+ + } else { + if err := n.transport.RemovePeer(cc.NodeID); err != nil { + return err + } } return n.cluster.RemoveMember(cc.NodeID) } -// ConnectToMember returns a member object with an initialized -// connection to communicate with other raft members -func (n *Node) ConnectToMember(addr string, timeout time.Duration) (*membership.Member, error) { - conn, err := dial(addr, "tcp", n.opts.TLSCredentials, timeout) - if err != nil { - return nil, err - } - - return &membership.Member{ - Conn: conn, - }, nil -} - // SubscribeLeadership returns channel to which events about leadership change // will be sent in form of raft.LeadershipState. Also cancel func is returned - // it should be called when listener is no longer interested in events. diff --git a/manager/state/raft/raft_test.go b/manager/state/raft/raft_test.go index d44cba1afe..303fde4791 100644 --- a/manager/state/raft/raft_test.go +++ b/manager/state/raft/raft_test.go @@ -63,6 +63,22 @@ func TestRaftBootstrap(t *testing.T) { assert.Len(t, nodes[3].GetMemberlist(), 3) } +func dial(n *raftutils.TestNode, addr string) (*grpc.ClientConn, error) { + grpcOptions := []grpc.DialOption{ + grpc.WithBackoffMaxDelay(2 * time.Second), + grpc.WithBlock(), + } + grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(n.SecurityConfig.ClientTLSCreds)) + + grpcOptions = append(grpcOptions, grpc.WithTimeout(10*time.Second)) + + cc, err := grpc.Dial(addr, grpcOptions...) 
+ if err != nil { + return nil, err + } + return cc, nil +} + func TestRaftJoinTwice(t *testing.T) { t.Parallel() @@ -72,10 +88,10 @@ func TestRaftJoinTwice(t *testing.T) { // Node 3 tries to join again // Use gRPC instead of calling handler directly because of // authorization check. - client, err := nodes[3].ConnectToMember(nodes[1].Address, 10*time.Second) + cc, err := dial(nodes[3], nodes[1].Address) assert.NoError(t, err) - raftClient := api.NewRaftMembershipClient(client.Conn) - defer client.Conn.Close() + raftClient := api.NewRaftMembershipClient(cc) + defer cc.Close() ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) _, err = raftClient.Join(ctx, &api.JoinRequest{}) assert.Error(t, err, "expected error on duplicate Join") @@ -279,10 +295,10 @@ func TestRaftFollowerLeave(t *testing.T) { // Node 5 leaves the cluster // Use gRPC instead of calling handler directly because of // authorization check. - client, err := nodes[1].ConnectToMember(nodes[1].Address, 10*time.Second) + cc, err := dial(nodes[1], nodes[1].Address) assert.NoError(t, err) - raftClient := api.NewRaftMembershipClient(client.Conn) - defer client.Conn.Close() + raftClient := api.NewRaftMembershipClient(cc) + defer cc.Close() ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) resp, err := raftClient.Leave(ctx, &api.LeaveRequest{Node: &api.RaftMember{RaftID: nodes[5].Config.ID}}) assert.NoError(t, err, "error sending message to leave the raft") @@ -323,11 +339,12 @@ func TestRaftLeaderLeave(t *testing.T) { // Try to leave the raft // Use gRPC instead of calling handler directly because of // authorization check. - client, err := nodes[1].ConnectToMember(nodes[1].Address, 10*time.Second) + cc, err := dial(nodes[1], nodes[1].Address) assert.NoError(t, err) - defer client.Conn.Close() - raftClient := api.NewRaftMembershipClient(client.Conn) - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + raftClient := api.NewRaftMembershipClient(cc) + defer cc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() resp, err := raftClient.Leave(ctx, &api.LeaveRequest{Node: &api.RaftMember{RaftID: nodes[1].Config.ID}}) assert.NoError(t, err, "error sending message to leave the raft") assert.NotNil(t, resp, "leave response message is nil") diff --git a/manager/state/raft/storage.go b/manager/state/raft/storage.go index 402b04a33f..e56f624783 100644 --- a/manager/state/raft/storage.go +++ b/manager/state/raft/storage.go @@ -60,10 +60,26 @@ func (n *Node) loadAndStart(ctx context.Context, forceNewCluster bool) error { n.Config.ID = raftNode.RaftID if snapshot != nil { - // Load the snapshot data into the store - if err := n.restoreFromSnapshot(snapshot.Data, forceNewCluster); err != nil { + snapCluster, err := n.clusterSnapshot(snapshot.Data) + if err != nil { return err } + var bootstrapMembers []*api.RaftMember + if forceNewCluster { + for _, m := range snapCluster.Members { + if m.RaftID != n.Config.ID { + n.cluster.RemoveMember(m.RaftID) + continue + } + bootstrapMembers = append(bootstrapMembers, m) + } + } else { + bootstrapMembers = snapCluster.Members + } + n.bootstrapMembers = bootstrapMembers + for _, removedMember := range snapCluster.Removed { + n.cluster.RemoveMember(removedMember) + } } ents, st := waldata.Entries, waldata.HardState @@ -215,40 +231,18 @@ func (n *Node) doSnapshot(ctx context.Context, raftConfig api.RaftConfig) { <-viewStarted } -func (n *Node) restoreFromSnapshot(data []byte, forceNewCluster bool) error { +func (n 
*Node) clusterSnapshot(data []byte) (api.ClusterSnapshot, error) { var snapshot api.Snapshot if err := snapshot.Unmarshal(data); err != nil { - return err + return snapshot.Membership, err } if snapshot.Version != api.Snapshot_V0 { - return fmt.Errorf("unrecognized snapshot version %d", snapshot.Version) + return snapshot.Membership, fmt.Errorf("unrecognized snapshot version %d", snapshot.Version) } if err := n.memoryStore.Restore(&snapshot.Store); err != nil { - return err - } - - oldMembers := n.cluster.Members() - - for _, member := range snapshot.Membership.Members { - if forceNewCluster && member.RaftID != n.Config.ID { - n.cluster.RemoveMember(member.RaftID) - } else { - if err := n.registerNode(&api.RaftMember{RaftID: member.RaftID, NodeID: member.NodeID, Addr: member.Addr}); err != nil { - return err - } - } - delete(oldMembers, member.RaftID) - } - - for _, removedMember := range snapshot.Membership.Removed { - n.cluster.RemoveMember(removedMember) - delete(oldMembers, removedMember) - } - - for member := range oldMembers { - n.cluster.ClearMember(member) + return snapshot.Membership, err } - return nil + return snapshot.Membership, nil } diff --git a/manager/state/raft/storage_test.go b/manager/state/raft/storage_test.go index 70d08b81ba..2602a573af 100644 --- a/manager/state/raft/storage_test.go +++ b/manager/state/raft/storage_test.go @@ -271,10 +271,10 @@ func TestRaftSnapshotForceNewCluster(t *testing.T) { // Use gRPC instead of calling handler directly because of // authorization check. - client, err := nodes[1].ConnectToMember(nodes[1].Address, 10*time.Second) + cc, err := dial(nodes[1], nodes[1].Address) assert.NoError(t, err) - raftClient := api.NewRaftMembershipClient(client.Conn) - defer client.Conn.Close() + raftClient := api.NewRaftMembershipClient(cc) + defer cc.Close() ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) resp, err := raftClient.Leave(ctx, &api.LeaveRequest{Node: &api.RaftMember{RaftID: nodes[2].Config.ID}}) assert.NoError(t, err, "error sending message to leave the raft") diff --git a/manager/state/raft/testutils/testutils.go b/manager/state/raft/testutils/testutils.go index 2c34b44d23..11f7827fe0 100644 --- a/manager/state/raft/testutils/testutils.go +++ b/manager/state/raft/testutils/testutils.go @@ -300,7 +300,6 @@ func NewNode(t *testing.T, clockSource *fakeclock.FakeClock, tc *cautils.TestCA, Config: cfg, StateDir: stateDir, ClockSource: clockSource, - SendTimeout: 10 * time.Second, TLSCredentials: securityConfig.ClientTLSCreds, KeyRotator: keyRotator, } @@ -414,7 +413,7 @@ func CopyNode(t *testing.T, clockSource *fakeclock.FakeClock, oldNode *TestNode, StateDir: oldNode.StateDir, ForceNewCluster: forceNewCluster, ClockSource: clockSource, - SendTimeout: 10 * time.Second, + SendTimeout: 2 * time.Second, TLSCredentials: securityConfig.ClientTLSCreds, KeyRotator: kr, } diff --git a/manager/state/raft/transport/mock_raft_test.go b/manager/state/raft/transport/mock_raft_test.go new file mode 100644 index 0000000000..7e7256de6a --- /dev/null +++ b/manager/state/raft/transport/mock_raft_test.go @@ -0,0 +1,168 @@ +package transport + +import ( + "net" + "time" + + "golang.org/x/net/context" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/manager/health" + "github.com/docker/swarmkit/manager/state/raft/membership" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +type snapshotReport struct { + id uint64 + status raft.SnapshotStatus 
+} + +type updateInfo struct { + id uint64 + addr string +} + +type mockRaft struct { + lis net.Listener + s *grpc.Server + tr *Transport + + nodeRemovedSignal chan struct{} + + removed map[uint64]bool + + processedMessages chan *raftpb.Message + processedSnapshots chan snapshotReport + + reportedUnreachables chan uint64 + updatedNodes chan updateInfo +} + +func newMockRaft() (*mockRaft, error) { + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return nil, err + } + mr := &mockRaft{ + lis: l, + s: grpc.NewServer(), + removed: make(map[uint64]bool), + nodeRemovedSignal: make(chan struct{}), + processedMessages: make(chan *raftpb.Message, 4096), + processedSnapshots: make(chan snapshotReport, 4096), + reportedUnreachables: make(chan uint64, 4096), + updatedNodes: make(chan updateInfo, 4096), + } + cfg := &Config{ + HeartbeatInterval: 3 * time.Second, + SendTimeout: 2 * time.Second, + Raft: mr, + } + tr := New(cfg) + mr.tr = tr + hs := health.NewHealthServer() + hs.SetServingStatus("Raft", api.HealthCheckResponse_SERVING) + api.RegisterRaftServer(mr.s, mr) + api.RegisterHealthServer(mr.s, hs) + go mr.s.Serve(l) + return mr, nil +} + +func (r *mockRaft) Addr() string { + return r.lis.Addr().String() +} + +func (r *mockRaft) Stop() { + r.tr.Stop() + r.s.Stop() +} + +func (r *mockRaft) RemovePeer(id uint64) error { + r.removed[id] = true + return r.tr.RemovePeer(id) +} + +func (r *mockRaft) ProcessRaftMessage(ctx context.Context, req *api.ProcessRaftMessageRequest) (*api.ProcessRaftMessageResponse, error) { + if r.removed[req.Message.From] { + return nil, grpc.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error()) + } + r.processedMessages <- req.Message + return &api.ProcessRaftMessageResponse{}, nil +} + +func (r *mockRaft) ResolveAddress(ctx context.Context, req *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) { + addr, err := r.tr.PeerAddr(req.RaftID) + if err != nil { + return nil, err + } + return &api.ResolveAddressResponse{ + Addr: addr, + }, nil +} + +func (r *mockRaft) ReportUnreachable(id uint64) { + r.reportedUnreachables <- id +} + +func (r *mockRaft) IsIDRemoved(id uint64) bool { + return r.removed[id] +} + +func (r *mockRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) { + r.processedSnapshots <- snapshotReport{ + id: id, + status: status, + } +} + +func (r *mockRaft) UpdateNode(id uint64, addr string) { + r.updatedNodes <- updateInfo{ + id: id, + addr: addr, + } +} + +func (r *mockRaft) NodeRemoved() { + close(r.nodeRemovedSignal) +} + +type mockCluster struct { + rafts map[uint64]*mockRaft +} + +func newCluster() *mockCluster { + return &mockCluster{ + rafts: make(map[uint64]*mockRaft), + } +} + +func (c *mockCluster) Stop() { + for _, r := range c.rafts { + r.s.Stop() + } +} + +func (c *mockCluster) Add(id uint64) error { + mr, err := newMockRaft() + if err != nil { + return err + } + for otherID, otherRaft := range c.rafts { + if err := mr.tr.AddPeer(otherID, otherRaft.Addr()); err != nil { + return err + } + if err := otherRaft.tr.AddPeer(id, mr.Addr()); err != nil { + return err + } + } + c.rafts[id] = mr + return nil +} + +func (c *mockCluster) Get(id uint64) *mockRaft { + return c.rafts[id] +} diff --git a/manager/state/raft/transport/peer.go b/manager/state/raft/transport/peer.go new file mode 100644 index 0000000000..4db7a141ab --- /dev/null +++ b/manager/state/raft/transport/peer.go @@ -0,0 +1,299 @@ +package transport + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/net/context" + + "google.golang.org/grpc" + 
"google.golang.org/grpc/codes" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/log" + "github.com/docker/swarmkit/manager/state/raft/membership" + "github.com/pkg/errors" +) + +type peer struct { + id uint64 + + tr *Transport + + msgc chan raftpb.Message + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + mu sync.Mutex + cc *grpc.ClientConn + addr string + newAddr string + + active bool + becameActive time.Time +} + +func newPeer(id uint64, addr string, tr *Transport) (*peer, error) { + cc, err := tr.dial(addr) + if err != nil { + return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr) + } + ctx, cancel := context.WithCancel(tr.ctx) + ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id)) + p := &peer{ + id: id, + addr: addr, + cc: cc, + tr: tr, + ctx: ctx, + cancel: cancel, + msgc: make(chan raftpb.Message, 4096), + done: make(chan struct{}), + } + go p.run(ctx) + return p, nil +} + +func (p *peer) send(m raftpb.Message) (err error) { + p.mu.Lock() + defer func() { + if err != nil { + p.active = false + p.becameActive = time.Time{} + } + p.mu.Unlock() + }() + select { + case <-p.ctx.Done(): + return p.ctx.Err() + default: + } + select { + case p.msgc <- m: + case <-p.ctx.Done(): + return p.ctx.Err() + default: + p.tr.config.ReportUnreachable(p.id) + return errors.Errorf("peer is unreachable") + } + return nil +} + +func (p *peer) update(addr string) error { + p.mu.Lock() + defer p.mu.Unlock() + if p.addr == addr { + return nil + } + cc, err := p.tr.dial(addr) + if err != nil { + return err + } + + p.cc.Close() + p.cc = cc + p.addr = addr + return nil +} + +func (p *peer) updateAddr(addr string) error { + p.mu.Lock() + defer p.mu.Unlock() + if p.addr == addr { + return nil + } + log.G(p.ctx).Debugf("peer %x updated to address %s, it will be used if old failed", p.id, addr) + p.newAddr = addr + return nil +} + +func (p *peer) conn() *grpc.ClientConn { + p.mu.Lock() + defer p.mu.Unlock() + return p.cc +} + +func (p *peer) address() string { + p.mu.Lock() + defer p.mu.Unlock() + return p.addr +} + +func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) { + resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id}) + if err != nil { + return "", errors.Wrap(err, "failed to resolve address") + } + return resp.Addr, nil +} + +func (p *peer) reportSnapshot(failure bool) { + if failure { + p.tr.config.ReportSnapshot(p.id, raft.SnapshotFailure) + return + } + p.tr.config.ReportSnapshot(p.id, raft.SnapshotFinish) +} + +func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error { + ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout) + defer cancel() + _, err := api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m}) + if grpc.Code(err) == codes.NotFound && grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() { + p.tr.config.NodeRemoved() + } + if m.Type == raftpb.MsgSnap { + if err != nil { + p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure) + } else { + } + } + p.reportSnapshot(err != nil) + if err != nil { + p.tr.config.ReportUnreachable(m.To) + return err + } + return nil +} + +func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error { + resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"}) + if err != nil { + return errors.Wrap(err, "failed to check health") + 
} + if resp.Status != api.HealthCheckResponse_SERVING { + return errors.Errorf("health check returned status %s", resp.Status) + } + return nil +} + +func (p *peer) healthCheck(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout) + defer cancel() + return healthCheckConn(ctx, p.conn()) +} + +func (p *peer) setActive() { + p.mu.Lock() + if !p.active { + p.active = true + p.becameActive = time.Now() + } + p.mu.Unlock() +} + +func (p *peer) setInactive() { + p.mu.Lock() + p.active = false + p.becameActive = time.Time{} + p.mu.Unlock() +} + +func (p *peer) activeTime() time.Time { + p.mu.Lock() + defer p.mu.Unlock() + return p.becameActive +} + +func (p *peer) drain() error { + ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second) + defer cancel() + for { + select { + case m, ok := <-p.msgc: + if !ok { + // all messages proceeded + return nil + } + if err := p.sendProcessMessage(ctx, m); err != nil { + return errors.Wrap(err, "send drain message") + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (p *peer) handleAddressChange(ctx context.Context) error { + p.mu.Lock() + newAddr := p.newAddr + p.newAddr = "" + p.mu.Unlock() + if newAddr == "" { + return nil + } + cc, err := p.tr.dial(newAddr) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout) + defer cancel() + if err := healthCheckConn(ctx, cc); err != nil { + cc.Close() + return err + } + // there is possibility of race if host changing address too fast, but + // it's unlikely and eventually thing should be settled + p.mu.Lock() + p.cc.Close() + p.cc = cc + p.addr = newAddr + p.tr.config.UpdateNode(p.id, p.addr) + p.mu.Unlock() + return nil +} + +func (p *peer) run(ctx context.Context) { + defer func() { + p.mu.Lock() + p.active = false + p.becameActive = time.Time{} + // at this point we can be sure that nobody will write to msgc + if p.msgc != nil { + close(p.msgc) + } + p.mu.Unlock() + if err := p.drain(); err != nil { + log.G(ctx).WithError(err).Error("failed to drain message queue") + } + close(p.done) + }() + if err := p.healthCheck(ctx); err == nil { + p.setActive() + } + for { + select { + case <-ctx.Done(): + return + default: + } + + select { + case m := <-p.msgc: + // we do not propagate context here, because this operation should be finished + // or timed out for correct raft work. + err := p.sendProcessMessage(context.Background(), m) + if err != nil { + log.G(ctx).WithError(err).Errorf("failed to send message %s", m.Type) + p.setInactive() + if err := p.handleAddressChange(ctx); err != nil { + log.G(ctx).WithError(err).Error("failed to change address after failure") + } + continue + } + p.setActive() + case <-ctx.Done(): + return + } + } +} + +func (p *peer) stop() { + p.cancel() + <-p.done +} diff --git a/manager/state/raft/transport/transport.go b/manager/state/raft/transport/transport.go new file mode 100644 index 0000000000..c3f86ee522 --- /dev/null +++ b/manager/state/raft/transport/transport.go @@ -0,0 +1,381 @@ +// Package transport provides grpc transport layer for raft. +// All methods are non-blocking. +package transport + +import ( + "sync" + "time" + + "golang.org/x/net/context" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/log" + "github.com/pkg/errors" +) + +// ErrIsNotFound indicates that peer was never added to transport. 
+var ErrIsNotFound = errors.New("peer not found") + +// Raft is interface which represents Raft API for transport package. +type Raft interface { + ReportUnreachable(id uint64) + ReportSnapshot(id uint64, status raft.SnapshotStatus) + IsIDRemoved(id uint64) bool + UpdateNode(id uint64, addr string) + + NodeRemoved() +} + +// Config for Transport +type Config struct { + HeartbeatInterval time.Duration + SendTimeout time.Duration + Credentials credentials.TransportCredentials + RaftID string + + Raft +} + +// Transport is structure which manages remote raft peers and sends messages +// to them. +type Transport struct { + config *Config + + unknownc chan raftpb.Message + + mu sync.Mutex + peers map[uint64]*peer + stopped bool + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + deferredConns map[*grpc.ClientConn]*time.Timer +} + +// New returns new Transport with specified Config. +func New(cfg *Config) *Transport { + ctx, cancel := context.WithCancel(context.Background()) + if cfg.RaftID != "" { + ctx = log.WithField(ctx, "raft_id", cfg.RaftID) + } + t := &Transport{ + peers: make(map[uint64]*peer), + config: cfg, + unknownc: make(chan raftpb.Message), + done: make(chan struct{}), + ctx: ctx, + cancel: cancel, + + deferredConns: make(map[*grpc.ClientConn]*time.Timer), + } + go t.run(ctx) + return t +} + +func (t *Transport) run(ctx context.Context) { + defer func() { + log.G(ctx).Debug("stop transport") + t.mu.Lock() + defer t.mu.Unlock() + t.stopped = true + for _, p := range t.peers { + p.stop() + } + for cc, timer := range t.deferredConns { + timer.Stop() + cc.Close() + } + t.deferredConns = nil + close(t.done) + }() + for { + select { + case <-ctx.Done(): + return + default: + } + + select { + case m := <-t.unknownc: + if err := t.sendUnknownMessage(ctx, m); err != nil { + log.G(ctx).WithError(err).Warnf("ignored message %s to unknown peer %x", m.Type, m.To) + } + case <-ctx.Done(): + return + } + } +} + +// Stop stops transport and waits until it finished +func (t *Transport) Stop() { + t.cancel() + <-t.done +} + +// Send sends raft message to remote peers. +func (t *Transport) Send(m raftpb.Message) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + if t.config.IsIDRemoved(m.To) { + return errors.Errorf("refusing to send message %s to removed member %x", m.Type, m.To) + } + p, ok := t.peers[m.To] + if !ok { + log.G(t.ctx).Warningf("sending message %s to an unrecognized member ID %x", m.Type, m.To) + select { + // we need to process messages to unknown peers in separate goroutine + // to not block sender + case t.unknownc <- m: + case <-t.ctx.Done(): + return t.ctx.Err() + default: + return errors.New("unknown messages queue is full") + } + return nil + } + if err := p.send(m); err != nil { + return errors.Wrapf(err, "failed to send message %x to %x", m.Type, m.To) + } + return nil +} + +// AddPeer adds new peer with id and address addr to Transport. +// If there is already peer with such id in Transport it will return error if +// address is different (UpdatePeer should be used) or nil otherwise. 
+func (t *Transport) AddPeer(id uint64, addr string) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + if ep, ok := t.peers[id]; ok { + if ep.address() == addr { + return nil + } + return errors.Errorf("peer %x already added with addr %s", id, ep.addr) + } + log.G(t.ctx).Debugf("transport: add peer %x with address %s", id, addr) + p, err := newPeer(id, addr, t) + if err != nil { + return errors.Wrapf(err, "failed to create peer %x with addr %s", id, addr) + } + t.peers[id] = p + return nil +} + +// RemovePeer removes peer from Transport and wait for it to stop. +func (t *Transport) RemovePeer(id uint64) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.stopped { + return errors.New("transport stopped") + } + p, ok := t.peers[id] + if !ok { + return ErrIsNotFound + } + delete(t.peers, id) + cc := p.conn() + p.stop() + timer := time.AfterFunc(8*time.Second, func() { + t.mu.Lock() + if !t.stopped { + delete(t.deferredConns, cc) + cc.Close() + } + t.mu.Unlock() + }) + // store connection and timer for cleaning up on stop + t.deferredConns[cc] = timer + + return nil +} + +// UpdatePeer updates peer with new address. It replaces connection immediately. +func (t *Transport) UpdatePeer(id uint64, addr string) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.stopped { + return errors.New("transport stopped") + } + p, ok := t.peers[id] + if !ok { + return ErrIsNotFound + } + if err := p.update(addr); err != nil { + return err + } + log.G(t.ctx).Debugf("peer %x updated to address %s", id, addr) + return nil +} + +// UpdatePeerAddr updates peer with new address, but delays connection creation. +// New address won't be used until first failure on old address. +func (t *Transport) UpdatePeerAddr(id uint64, addr string) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.stopped { + return errors.New("transport stopped") + } + p, ok := t.peers[id] + if !ok { + return ErrIsNotFound + } + if err := p.updateAddr(addr); err != nil { + return err + } + return nil +} + +// PeerConn returns raw grpc connection to peer. +func (t *Transport) PeerConn(id uint64) (*grpc.ClientConn, error) { + t.mu.Lock() + defer t.mu.Unlock() + p, ok := t.peers[id] + if !ok { + return nil, ErrIsNotFound + } + p.mu.Lock() + active := p.active + p.mu.Unlock() + if !active { + return nil, errors.New("peer is inactive") + } + return p.conn(), nil +} + +// PeerAddr returns address of peer with id. +func (t *Transport) PeerAddr(id uint64) (string, error) { + t.mu.Lock() + defer t.mu.Unlock() + p, ok := t.peers[id] + if !ok { + return "", ErrIsNotFound + } + return p.address(), nil +} + +// HealthCheck checks health of particular peer. +func (t *Transport) HealthCheck(ctx context.Context, id uint64) error { + t.mu.Lock() + p, ok := t.peers[id] + t.mu.Unlock() + if !ok { + return ErrIsNotFound + } + ctx, cancel := t.withContext(ctx) + defer cancel() + return p.healthCheck(ctx) +} + +// Active returns true if node was recently active and false otherwise. 
+func (t *Transport) Active(id uint64) bool { + t.mu.Lock() + defer t.mu.Unlock() + p, ok := t.peers[id] + if !ok { + return false + } + p.mu.Lock() + active := p.active + p.mu.Unlock() + return active +} + +func (t *Transport) longestActive() (*peer, error) { + var longest *peer + var longestTime time.Time + t.mu.Lock() + defer t.mu.Unlock() + for _, p := range t.peers { + becameActive := p.activeTime() + if becameActive.IsZero() { + continue + } + if longest == nil { + longest = p + continue + } + if becameActive.Before(longestTime) { + longest = p + longestTime = becameActive + } + } + if longest == nil { + return nil, errors.New("failed to find longest active peer") + } + return longest, nil +} + +func (t *Transport) dial(addr string) (*grpc.ClientConn, error) { + grpcOptions := []grpc.DialOption{ + grpc.WithBackoffMaxDelay(8 * time.Second), + } + if t.config.Credentials != nil { + grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(t.config.Credentials)) + } else { + grpcOptions = append(grpcOptions, grpc.WithInsecure()) + } + + if t.config.SendTimeout > 0 { + grpcOptions = append(grpcOptions, grpc.WithTimeout(t.config.SendTimeout)) + } + + cc, err := grpc.Dial(addr, grpcOptions...) + if err != nil { + return nil, err + } + + return cc, nil +} + +func (t *Transport) withContext(ctx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + + go func() { + select { + case <-ctx.Done(): + case <-t.ctx.Done(): + cancel() + } + }() + return ctx, cancel +} + +func (t *Transport) resolvePeer(ctx context.Context, id uint64) (*peer, error) { + longestActive, err := t.longestActive() + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(ctx, t.config.SendTimeout) + defer cancel() + addr, err := longestActive.resolveAddr(ctx, id) + if err != nil { + return nil, err + } + return newPeer(id, addr, t) +} + +func (t *Transport) sendUnknownMessage(ctx context.Context, m raftpb.Message) error { + p, err := t.resolvePeer(ctx, m.To) + if err != nil { + return errors.Wrapf(err, "failed to resolve peer") + } + defer p.cancel() + if err := p.sendProcessMessage(ctx, m); err != nil { + return errors.Wrapf(err, "failed to send message") + } + return nil +} diff --git a/manager/state/raft/transport/transport_test.go b/manager/state/raft/transport/transport_test.go new file mode 100644 index 0000000000..2989469699 --- /dev/null +++ b/manager/state/raft/transport/transport_test.go @@ -0,0 +1,286 @@ +package transport + +import ( + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func sendMessages(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) error { + var firstErr error + for _, id := range to { + err := c.Get(from).tr.Send(raftpb.Message{ + Type: msgType, + From: from, + To: id, + }) + if firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func testSend(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) func(*testing.T) { + return func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + require.NoError(t, sendMessages(ctx, c, from, to, msgType)) + + for _, id := range to { + select { + case msg := <-c.Get(id).processedMessages: + assert.Equal(t, msg.To, id) + assert.Equal(t, msg.From, from) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + } + 
+ if msgType == raftpb.MsgSnap { + var snaps []snapshotReport + for i := 0; i < len(to); i++ { + select { + case snap := <-c.Get(from).processedSnapshots: + snaps = append(snaps, snap) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + } + loop: + for _, id := range to { + for _, s := range snaps { + if s.id == id { + assert.Equal(t, s.status, raft.SnapshotFinish) + continue loop + } + } + t.Fatalf("shapshot ot %d is not reported", id) + } + } + } +} + +func TestSend(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + t.Run("Send Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) + t.Run("Send_Snapshot_Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgSnap)) +} + +func TestSendRemoved(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + require.NoError(t, c.Get(1).RemovePeer(2)) + + err := sendMessages(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup) + require.Error(t, err) + require.Contains(t, err.Error(), "to removed member") +} + +func TestSendSnapshotFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + // stop peer server to emulate error + c.Get(2).s.Stop() + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap)) + + select { + case snap := <-c.Get(1).processedSnapshots: + assert.Equal(t, snap.id, uint64(2)) + assert.Equal(t, snap.status, raft.SnapshotFailure) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } + + select { + case id := <-c.Get(1).reportedUnreachables: + assert.Equal(t, id, uint64(2)) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } +} + +func TestSendUnknown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + // remove peer from 1 transport to make it "unknown" to it + oldPeer := c.Get(1).tr.peers[2] + delete(c.Get(1).tr.peers, 2) + oldPeer.cancel() + <-oldPeer.done + + // give peers time to mark each other as active + time.Sleep(1 * time.Second) + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgHup)) + + select { + case msg := <-c.Get(2).processedMessages: + assert.Equal(t, msg.To, uint64(2)) + assert.Equal(t, msg.From, uint64(1)) + case <-msgCtx.Done(): + t.Fatal(msgCtx.Err()) + } +} + +func TestUpdatePeerAddr(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + t.Run("Send Message Before Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) + + nr, err := newMockRaft() + require.NoError(t, err) + + c.Get(3).Stop() + c.rafts[3] = nr + + require.NoError(t, c.Get(1).tr.UpdatePeer(3, nr.Addr())) + require.NoError(t, c.Get(1).tr.UpdatePeer(3, nr.Addr())) + + 
t.Run("Send Message After Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) +} + +func TestUpdatePeerAddrDelayed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + t.Run("Send Message Before Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) + + nr, err := newMockRaft() + require.NoError(t, err) + + c.Get(3).Stop() + c.rafts[3] = nr + + require.NoError(t, c.Get(1).tr.UpdatePeerAddr(3, nr.Addr())) + + // initiate failure to replace connection, and wait for it + sendMessages(ctx, c, 1, []uint64{3}, raftpb.MsgHup) + updateCtx, updateCancel := context.WithTimeout(ctx, 4*time.Second) + defer updateCancel() + select { + case update := <-c.Get(1).updatedNodes: + require.Equal(t, update.id, uint64(3)) + require.Equal(t, update.addr, nr.Addr()) + case <-updateCtx.Done(): + t.Fatal(updateCtx.Err()) + } + + t.Run("Send Message After Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) +} + +func TestSendUnreachable(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + // set channel to nil to emulate full queue + // we need to reset some fields after cancel + p2 := c.Get(1).tr.peers[2] + p2.cancel() + <-p2.done + p2.msgc = nil + p2.done = make(chan struct{}) + p2.ctx = ctx + go p2.run(ctx) + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + err := sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap) + require.Error(t, err) + require.Contains(t, err.Error(), "peer is unreachable") + select { + case id := <-c.Get(1).reportedUnreachables: + assert.Equal(t, id, uint64(2)) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } +} + +func TestSendNodeRemoved(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster() + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + require.NoError(t, c.Get(1).RemovePeer(2)) + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 2, []uint64{1}, raftpb.MsgSnap)) + select { + case <-c.Get(2).nodeRemovedSignal: + case <-msgCtx.Done(): + t.Fatal(msgCtx.Err()) + } +}
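For reviewers, a minimal, self-contained sketch of how the new transport package is intended to be used. The noopRaft type, node IDs, and addresses below are hypothetical and exist only for illustration; error handling is reduced to panics:

package main

import (
	"time"

	"github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/manager/state/raft/transport"
)

// noopRaft is a hypothetical stand-in for the raft node that receives
// reports from the transport.
type noopRaft struct{}

func (noopRaft) ReportUnreachable(id uint64)                          {}
func (noopRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
func (noopRaft) IsIDRemoved(id uint64) bool                           { return false }
func (noopRaft) UpdateNode(id uint64, addr string)                    {}
func (noopRaft) NodeRemoved()                                         {}

func main() {
	tr := transport.New(&transport.Config{
		HeartbeatInterval: time.Second,
		SendTimeout:       5 * time.Second,
		Raft:              noopRaft{},
	})
	defer tr.Stop()

	// address is illustrative; with no Credentials set, the transport dials insecurely
	if err := tr.AddPeer(2, "127.0.0.1:4242"); err != nil {
		panic(err)
	}

	// Send queues the message on the peer's outgoing channel and returns
	// without blocking; delivery happens in the peer's goroutine.
	if err := tr.Send(raftpb.Message{Type: raftpb.MsgHup, From: 1, To: 2}); err != nil {
		panic(err)
	}
}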