Skip to content

Commit

Permalink
advancedtls: populate verified chains when using custom buildVerifyFu…
Browse files Browse the repository at this point in the history
…nc (#7181)

* populate verified chains when using custom buildVerifyFunc
  • Loading branch information
mudhireddy authored May 22, 2024
1 parent 1db6590 commit 5ffe0ef
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 6 deletions.
27 changes: 21 additions & 6 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ import (
credinternal "google.golang.org/grpc/internal/credentials"
)

type CertificateChains [][]*x509.Certificate

// HandshakeVerificationInfo contains information about a handshake needed for
// verification for use when implementing the `PostHandshakeVerificationFunc`
// The fields in this struct are read-only.
Expand All @@ -53,7 +55,7 @@ type HandshakeVerificationInfo struct {
RawCerts [][]byte
// The verification chain obtained by checking peer RawCerts against the
// trust certificate bundle(s), if applicable.
VerifiedChains [][]*x509.Certificate
VerifiedChains CertificateChains
// The leaf certificate sent from peer, if choosing to verify the peer
// certificate(s) and that verification passed. This field would be nil if
// either user chose not to verify or the verification failed.
Expand Down Expand Up @@ -552,7 +554,8 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
if cfg.ServerName == "" {
cfg.ServerName = authority
}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn)
peerVerifiedChains := CertificateChains{}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn, &peerVerifiedChains)
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
Expand All @@ -576,12 +579,14 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
info.State.VerifiedChains = peerVerifiedChains
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}

func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
cfg := credinternal.CloneTLSConfig(c.config)
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
peerVerifiedChains := CertificateChains{}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn, &peerVerifiedChains)
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
conn.Close()
Expand All @@ -594,6 +599,7 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
info.State.VerifiedChains = peerVerifiedChains
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}

Expand All @@ -618,9 +624,15 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
// 1. does not have a good support on root cert reloading.
// 2. will ignore basic certificate check when setting InsecureSkipVerify
// to true.
//
// peerVerifiedChains(output param): verified chain of certs from leaf to the
// trust cert that the peer trusts.
// 1. For server it is, client certs + Root ca that the server trusts
// 2. For client it is, server certs + Root ca that the client trusts
func buildVerifyFunc(c *advancedTLSCreds,
serverName string,
rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
rawConn net.Conn,
peerVerifiedChains *CertificateChains) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
chains := verifiedChains
var leafCert *x509.Certificate
Expand Down Expand Up @@ -684,7 +696,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
if c.revocationOptions != nil {
verifiedChains := chains
if verifiedChains == nil {
verifiedChains = [][]*x509.Certificate{rawCertList}
verifiedChains = CertificateChains{rawCertList}
}
if err := checkChainRevocation(verifiedChains, *c.revocationOptions); err != nil {
return err
Expand All @@ -698,8 +710,11 @@ func buildVerifyFunc(c *advancedTLSCreds,
VerifiedChains: chains,
Leaf: leafCert,
})
return err
if err != nil {
return err
}
}
*peerVerifiedChains = chains
return nil
}
}
Expand Down
71 changes: 71 additions & 0 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package advancedtls

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -949,6 +950,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
clientAuthInfo, serverAuthInfo)
}
serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains
if test.serverMutualTLS && !test.serverExpectError {
if len(serverVerifiedChains) == 0 {
t.Fatalf("server verified chains is empty")
}
var clientCert *tls.Certificate
if len(test.clientCert) > 0 {
clientCert = &test.clientCert[0]
} else if test.clientGetCert != nil {
cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
clientCert = cert
} else if test.clientIdentityProvider != nil {
km, _ := test.clientIdentityProvider.KeyMaterial(context.TODO())
clientCert = &km.Certs[0]
}
if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
t.Fatal("server verifiedChains leaf cert doesn't match client cert")
}

var serverRoot *x509.CertPool
if test.serverRoot != nil {
serverRoot = test.serverRoot
} else if test.serverGetRoot != nil {
result, _ := test.serverGetRoot(&GetRootCAsParams{})
serverRoot = result.TrustCerts
} else if test.serverRootProvider != nil {
km, _ := test.serverRootProvider.KeyMaterial(context.TODO())
serverRoot = km.Roots
}
serverVerifiedChainsCp := x509.NewCertPool()
serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1])
if !serverVerifiedChainsCp.Equal(serverRoot) {
t.Fatalf("server verified chain hierarchy doesn't match")
}
}
clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains
if test.serverMutualTLS && !test.clientExpectHandshakeError {
if len(clientVerifiedChains) == 0 {
t.Fatalf("client verified chains is empty")
}
var serverCert *tls.Certificate
if len(test.serverCert) > 0 {
serverCert = &test.serverCert[0]
} else if test.serverGetCert != nil {
cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
serverCert = cert[0]
} else if test.serverIdentityProvider != nil {
km, _ := test.serverIdentityProvider.KeyMaterial(context.TODO())
serverCert = &km.Certs[0]
}
if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
t.Fatal("client verifiedChains leaf cert doesn't match server cert")
}

var clientRoot *x509.CertPool
if test.clientRoot != nil {
clientRoot = test.clientRoot
} else if test.clientGetRoot != nil {
result, _ := test.clientGetRoot(&GetRootCAsParams{})
clientRoot = result.TrustCerts
} else if test.clientRootProvider != nil {
km, _ := test.clientRootProvider.KeyMaterial(context.TODO())
clientRoot = km.Roots
}
clientVerifiedChainsCp := x509.NewCertPool()
clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1])
if !clientVerifiedChainsCp.Equal(clientRoot) {
t.Fatalf("client verified chain hierarchy doesn't match")
}
}
})
}
}
Expand Down

0 comments on commit 5ffe0ef

Please sign in to comment.