Skip to content

Commit

Permalink
return non temporary connection error if dialer returns non temprary …
Browse files Browse the repository at this point in the history
…errors
  • Loading branch information
menghanl committed Nov 10, 2016
1 parent c098900 commit 947e436
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
2 changes: 2 additions & 0 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ func WithTimeout(d time.Duration) DialOption {
}

// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
// If an error is returned by the dial function, gRPC checks its Temporary() method to decide if
// re-dialing is necessary.
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
Expand Down
35 changes: 33 additions & 2 deletions clientconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
package grpc

import (
"net"
"testing"
"time"

Expand All @@ -45,8 +46,12 @@ import (

const tlsDir = "testdata/"

func temporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, error) {
return nil, &errorWithTemp{true} // Always return temporary error.
}

func TestDialTimeout(t *testing.T) {
conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure())
conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure(), WithDialer(temporaryErrorDialer))
if err == nil {
conn.Close()
}
Expand All @@ -60,7 +65,7 @@ func TestTLSDialTimeout(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock())
conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithDialer(temporaryErrorDialer))
if err == nil {
conn.Close()
}
Expand Down Expand Up @@ -188,3 +193,29 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
}
conn.Close()
}

type errorWithTemp struct {
temp bool
}

func (e *errorWithTemp) Error() string {
return "non-temprary-error"
}

func (e *errorWithTemp) Temporary() bool {
return e.temp
}

var nonTemporaryError = &errorWithTemp{false}

func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, error) {
return nil, nonTemporaryError
}

func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
_, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock())
if err != nonTemporaryError {
t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError)
}
}
18 changes: 11 additions & 7 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,10 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
}
te.srv.Stop()
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
_, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false))
if e.balancer && grpc.Code(err) != codes.DeadlineExceeded {
// If e.balancer == nil, the ac will stop reconnecting because the dialer returns non-temp error,
// the error will be an internal error.
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %s", ctx, err, codes.DeadlineExceeded)
}
awaitNewConnLogOutput()
Expand Down Expand Up @@ -993,18 +996,19 @@ func testFailFast(t *testing.T, e env) {
// Loop until the server teardown is propagated to the client.
for {
_, err := tc.EmptyCall(context.Background(), &testpb.Empty{})
if grpc.Code(err) == codes.Unavailable {
if grpc.Code(err) == codes.Internal {
break
}
fmt.Printf("%v.EmptyCall(_, _) = _, %v", tc, err)
time.Sleep(10 * time.Millisecond)
}
// The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable.
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %s", err, codes.Unavailable)
// The client stops reconnecting because dial function returns non-temporary error,
// ongoing fail-fast RPCs should fail with code.Internal.
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %s", err, codes.Internal)
}
if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Unavailable {
t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %s", err, codes.Unavailable)
if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %s", err, codes.Internal)
}

awaitNewConnLogOutput()
Expand Down
2 changes: 1 addition & 1 deletion transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
scheme := "http"
conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
}
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
Expand Down

0 comments on commit 947e436

Please sign in to comment.