Skip to content

Commit

Permalink
Merge pull request crewjam#26 from edaniels/master
Browse files Browse the repository at this point in the history
Be SAML xsd:dateTime conformant; Update JWT pkg usage
  • Loading branch information
crewjam authored Jul 20, 2016
2 parents 6139d7a + 7d69af3 commit bb9cca9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
19 changes: 12 additions & 7 deletions samlsp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {

secretBlock, _ := pem.Decode([]byte(m.ServiceProvider.Key))
state := jwt.New(jwt.GetSigningMethod("HS256"))
state.Claims["id"] = req.ID
state.Claims["uri"] = r.URL.String()
claims := state.Claims.(jwt.MapClaims)
claims["id"] = req.ID
claims["uri"] = r.URL.String()
signedState, err := state.SignedString(secretBlock.Bytes)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -171,7 +172,8 @@ func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string {
log.Printf("... invalid token %s", err)
continue
}
rv = append(rv, token.Claims["id"].(string))
claims := token.Claims.(jwt.MapClaims)
rv = append(rv, claims["id"].(string))
}

// If IDP initiated requests are allowed, then we can expect an empty response ID.
Expand Down Expand Up @@ -205,7 +207,8 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
redirectURI = state.Claims["uri"].(string)
claims := state.Claims.(jwt.MapClaims)
redirectURI = claims["uri"].(string)

// delete the cookie
stateCookie.Value = ""
Expand All @@ -214,6 +217,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
}

token := jwt.New(jwt.GetSigningMethod("HS256"))
claims := token.Claims.(jwt.MapClaims)
for _, attr := range assertion.AttributeStatement.Attributes {
valueStrings := []string{}
for _, v := range attr.Values {
Expand All @@ -223,9 +227,9 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
if claimName == "" {
claimName = attr.Name
}
token.Claims[claimName] = valueStrings
claims[claimName] = valueStrings
}
token.Claims["exp"] = saml.TimeNow().Add(cookieMaxAge).Unix()
claims["exp"] = saml.TimeNow().Add(cookieMaxAge).Unix()
signedToken, err := token.SignedString(secretBlock.Bytes)
if err != nil {
panic(err)
Expand Down Expand Up @@ -275,7 +279,8 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool {
}
}

for claimName, claimValue := range token.Claims {
claims := token.Claims.(jwt.MapClaims)
for claimName, claimValue := range claims {
if claimName == "exp" {
continue
}
Expand Down
2 changes: 1 addition & 1 deletion util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

// TimeNow is a function that returns the current time. The default
// value is time.Now, but it can be replaced for testing.
var TimeNow = time.Now
var TimeNow = func() time.Time { return time.Now().UTC() }

// RandReader is the io.Reader that produces cryptographically random
// bytes when they are need by the library. The default value is
Expand Down

0 comments on commit bb9cca9

Please sign in to comment.