Skip to content

Commit

Permalink
Create PAT supporting auth middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsanbarkati committed Feb 28, 2023
1 parent 96267e2 commit 7973525
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 213 deletions.
42 changes: 21 additions & 21 deletions ee/query-service/app/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,65 +69,65 @@ func (ah *APIHandler) CheckFeature(f string) bool {
}

// RegisterRoutes registers routes for this handler on the given router
func (ah *APIHandler) RegisterRoutes(router *mux.Router) {
func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddleware) {
// note: add ee override methods first

// routes available only in ee version
router.HandleFunc("/api/v1/licenses",
baseapp.AdminAccess(ah.listLicenses)).
am.AdminAccess(ah.listLicenses)).
Methods(http.MethodGet)

router.HandleFunc("/api/v1/licenses",
baseapp.AdminAccess(ah.applyLicense)).
am.AdminAccess(ah.applyLicense)).
Methods(http.MethodPost)

router.HandleFunc("/api/v1/featureFlags",
baseapp.OpenAccess(ah.getFeatureFlags)).
am.OpenAccess(ah.getFeatureFlags)).
Methods(http.MethodGet)

router.HandleFunc("/api/v1/loginPrecheck",
baseapp.OpenAccess(ah.precheckLogin)).
am.OpenAccess(ah.precheckLogin)).
Methods(http.MethodGet)

// paid plans specific routes
router.HandleFunc("/api/v1/complete/saml",
baseapp.OpenAccess(ah.receiveSAML)).
am.OpenAccess(ah.receiveSAML)).
Methods(http.MethodPost)

router.HandleFunc("/api/v1/complete/google",
baseapp.OpenAccess(ah.receiveGoogleAuth)).
am.OpenAccess(ah.receiveGoogleAuth)).
Methods(http.MethodGet)

router.HandleFunc("/api/v1/orgs/{orgId}/domains",
baseapp.AdminAccess(ah.listDomainsByOrg)).
am.AdminAccess(ah.listDomainsByOrg)).
Methods(http.MethodGet)

router.HandleFunc("/api/v1/domains",
baseapp.AdminAccess(ah.postDomain)).
am.AdminAccess(ah.postDomain)).
Methods(http.MethodPost)

router.HandleFunc("/api/v1/domains/{id}",
baseapp.AdminAccess(ah.putDomain)).
am.AdminAccess(ah.putDomain)).
Methods(http.MethodPut)

router.HandleFunc("/api/v1/domains/{id}",
baseapp.AdminAccess(ah.deleteDomain)).
am.AdminAccess(ah.deleteDomain)).
Methods(http.MethodDelete)

// base overrides
router.HandleFunc("/api/v1/version", baseapp.OpenAccess(ah.getVersion)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/invite/{token}", baseapp.OpenAccess(ah.getInvite)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/register", baseapp.OpenAccess(ah.registerUser)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/login", baseapp.OpenAccess(ah.loginUser)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/traces/{traceId}", baseapp.ViewAccess(ah.searchTraces)).Methods(http.MethodGet)
router.HandleFunc("/api/v2/metrics/query_range", baseapp.ViewAccess(ah.queryRangeMetricsV2)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/version", am.OpenAccess(ah.getVersion)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/invite/{token}", am.OpenAccess(ah.getInvite)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/register", am.OpenAccess(ah.registerUser)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/login", am.OpenAccess(ah.loginUser)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/traces/{traceId}", am.ViewAccess(ah.searchTraces)).Methods(http.MethodGet)
router.HandleFunc("/api/v2/metrics/query_range", am.ViewAccess(ah.queryRangeMetricsV2)).Methods(http.MethodPost)

// PAT APIs
router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.createPAT)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.getPATs)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/pat/{id}", baseapp.SelfAccess(ah.deletePAT)).Methods(http.MethodDelete)
router.HandleFunc("/api/v1/pat", am.OpenAccess(ah.createPAT)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/pat", am.OpenAccess(ah.getPATs)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/pat/{id}", am.OpenAccess(ah.deletePAT)).Methods(http.MethodDelete)

ah.APIHandler.RegisterRoutes(router)
ah.APIHandler.RegisterRoutes(router, am)

}

Expand Down
30 changes: 29 additions & 1 deletion ee/query-service/app/api/pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -55,7 +56,14 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {

func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
user, _ := auth.GetUserFromRequest(r)
user, err := auth.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
zap.S().Infof("Get PATs for user: %+v", user.Id)
pats, apierr := ah.AppDao().ListPATs(ctx, user.Id)
if apierr != nil {
Expand All @@ -68,6 +76,26 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
id := mux.Vars(r)["id"]
user, err := auth.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
pat, apierr := ah.AppDao().GetPAT(ctx, id)
if apierr != nil {
RespondError(w, apierr, nil)
return
}
if pat.UserID != user.Id {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: fmt.Errorf("unauthorized PAT delete request"),
}, nil)
return
}
zap.S().Infof("Delete PAT with id: %+v", id)
if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil {
RespondError(w, apierr, nil)
Expand Down
53 changes: 50 additions & 3 deletions ee/query-service/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
_ "net/http/pprof" // http profiler
"os"
"strings"
"time"

"github.com/gorilla/handlers"
Expand All @@ -25,7 +26,9 @@ import (
licensepkg "go.signoz.io/signoz/ee/query-service/license"
"go.signoz.io/signoz/ee/query-service/usage"

baseapp "go.signoz.io/signoz/pkg/query-service/app"
"go.signoz.io/signoz/pkg/query-service/app/dashboards"
baseauth "go.signoz.io/signoz/pkg/query-service/auth"
baseconst "go.signoz.io/signoz/pkg/query-service/constants"
"go.signoz.io/signoz/pkg/query-service/healthcheck"
basealm "go.signoz.io/signoz/pkg/query-service/integrations/alertManager"
Expand Down Expand Up @@ -199,17 +202,61 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
}, nil
}

func getPATToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil
}

authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", fmt.Errorf("authorization header format must be Bearer {token}")
}

return authHeaderParts[1], nil
}

func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, error) {

r := mux.NewRouter()

getUserFromPAT := func(r *http.Request) (*model.UserPayload, error) {
patToken, err := getPATToken(r)
if err != nil {
return nil, fmt.Errorf("failed to get PAT token in request headers, err: %v", err)
}
ctx := context.Background()
dao := apiHandler.AppDao()
pat, err := dao.GetPAT(ctx, patToken)
if err != nil {
return nil, fmt.Errorf("failed to fetch PAT token from DB, err %v", err)
}
user, apierr := dao.GetUser(ctx, pat.UserID)
if apierr != nil {
return nil, fmt.Errorf("failed to fetch user for PAT from DB, err: %v", apierr)
}
return user, nil
}

getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) {
user, err := getUserFromPAT(r)
if err == nil && user != nil {
zap.S().Debugf("Found valid PAT user: %+v", user)
return user, nil
}
if err != nil {
zap.S().Debugf("Error while getting user for PAT: %+v", err)
}
return baseauth.GetUserFromRequest(r)
}
am := baseapp.NewAuthMiddleware(getUserFromRequest)
r.Use(setTimeoutMiddleware)
r.Use(s.analyticsMiddleware)
r.Use(loggingMiddleware)

apiHandler.RegisterRoutes(r)
apiHandler.RegisterMetricsRoutes(r)
apiHandler.RegisterLogsRoutes(r)
apiHandler.RegisterRoutes(r, am)
apiHandler.RegisterMetricsRoutes(r, am)
apiHandler.RegisterLogsRoutes(r, am)

c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
Expand Down
2 changes: 1 addition & 1 deletion ee/query-service/dao/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type ModelDao interface {
GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError)

CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError
GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError)
GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError)
ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError)
DeletePAT(ctx context.Context, id string) basemodel.BaseApiError
}
7 changes: 3 additions & 4 deletions ee/query-service/dao/sqlite/pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,17 @@ func (m *modelDao) DeletePAT(ctx context.Context, id string) basemodel.BaseApiEr
return nil
}

func (m *modelDao) GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError) {
func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemodel.BaseApiError) {
pats := []model.PAT{}

if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=?;`, patID); err != nil {
zap.S().Errorf("Failed to fetch PAT with ID: %s, err: %v", patID, zap.Error(err))
if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE token=?;`, token); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch PAT"))
}

if len(pats) != 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: fmt.Errorf("found zero or multiple PATS with same ID"),
Err: fmt.Errorf("found zero or multiple PATs with same token, %s", token),
}
}

Expand Down
116 changes: 116 additions & 0 deletions pkg/query-service/app/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package app

import (
"errors"
"net/http"

"github.com/gorilla/mux"
"go.signoz.io/signoz/pkg/query-service/auth"
"go.signoz.io/signoz/pkg/query-service/model"
)

type AuthMiddleware struct {
GetUserFromRequest func(r *http.Request) (*model.UserPayload, error)
}

func NewAuthMiddleware(f func(r *http.Request) (*model.UserPayload, error)) *AuthMiddleware {
return &AuthMiddleware{
GetUserFromRequest: f,
}
}

// func (am *AuthMiddleware) GetUserFromRequest(r *http.Request) (*model.UserPayload, error) {
// return auth.GetUserFromRequest(r)
// }

func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
f(w, r)
}
}

func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}

if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to viewers/editors/admins."),
}, nil)
return
}
f(w, r)
}
}

func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !(auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to editors/admins."),
}, nil)
return
}
f(w, r)
}
}

func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
id := mux.Vars(r)["id"]
if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible for self access or to the admins."),
}, nil)
return
}
f(w, r)
}
}

func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r)
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !auth.IsAdmin(user) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to admins only"),
}, nil)
return
}
f(w, r)
}
}
Loading

0 comments on commit 7973525

Please sign in to comment.