package internal

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
	"github.com/julienschmidt/httprouter"
	log "github.com/sirupsen/logrus"

	"git.mills.io/prologic/spyda/internal/passwords"
)

// ContextKey is the type of the keys the API stores in a request's context.
// Using a distinct type keeps these keys from colliding with context keys of
// other types.
type ContextKey int

const (
	// TokenContextKey maps to the parsed *jwt.Token of an authorized request.
	TokenContextKey ContextKey = iota
	// UserContextKey maps to the *User loaded for an authorized request.
	UserContextKey
)

var (
	// ErrInvalidCredentials is returned for invalid credentials against /auth.
	ErrInvalidCredentials = errors.New("error: invalid credentials")

	// ErrInvalidToken is returned for expired or invalid tokens used to
	// authenticate API requests.
	ErrInvalidToken = errors.New("error: invalid token")
)

// API serves the JSON API mounted under /api/v1 and implements token-based
// authentication for its endpoints.
type API struct {
	router *Router
	config *Config
	db     Store
	pm     passwords.Passwords
}

// NewAPI constructs an API from its dependencies and registers its routes on
// router.
func NewAPI(router *Router, config *Config, db Store, pm passwords.Passwords) *API {
	api := &API{router: router, config: config, db: db, pm: pm}
	api.initRoutes()
	return api
}

// initRoutes registers the API's endpoints under the /api/v1 route group.
func (a *API) initRoutes() {
	router := a.router.Group("/api/v1")
	router.GET("/ping", a.PingEndpoint())
}

// CreateToken issues an HS256-signed JWT for user. The token is signed with
// config.APISigningKey and carries the user's username in the "username"
// claim; no expiry claim is set here. r is used only to record the client's
// User-Agent on the returned Token.
func (a *API) CreateToken(user *User, r *http.Request) (*Token, error) {
	claims := jwt.MapClaims{"username": user.Username}
	createdAt := time.Now()

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(a.config.APISigningKey))
	if err != nil {
		log.WithError(err).Error("error creating signed token")
		return nil, err
	}

	// Re-parse the signed string to obtain its signature segment, which is
	// stored on the returned Token alongside the full value.
	signedToken, err := jwt.Parse(tokenString, a.jwtKeyFunc)
	if err != nil {
		log.WithError(err).Error("error parsing signed token")
		return nil, err
	}

	return &Token{
		Signature: signedToken.Signature,
		Value:     tokenString,
		UserAgent: r.UserAgent(),
		CreatedAt: createdAt,
	}, nil
}

// jwtKeyFunc is the jwt.Keyfunc used to verify tokens. It rejects any token
// whose signing method is not HMAC, so tokens signed with another algorithm
// are never checked against the shared secret. Otherwise it returns the
// configured signing key.
func (a *API) jwtKeyFunc(token *jwt.Token) (interface{}, error) {
	if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
		return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
	}
	return []byte(a.config.APISigningKey), nil
}

// tokenUsername extracts the "username" claim from a parsed token. ok is
// false when the claims are not a map or the claim is absent, not a string,
// or empty. Callers can then reject the token instead of panicking on an
// unchecked type assertion.
func (a *API) tokenUsername(token *jwt.Token) (username string, ok bool) {
	claims, isMap := token.Claims.(jwt.MapClaims)
	if !isMap {
		return "", false
	}
	username, ok = claims["username"].(string)
	return username, ok && username != ""
}

// getLoggedInUser returns the user authenticated by the request's "Token"
// header. It returns nil if the header is missing, the token fails
// verification or lacks a username claim, or the user cannot be loaded.
func (a *API) getLoggedInUser(r *http.Request) *User {
	token, err := jwt.Parse(r.Header.Get("Token"), a.jwtKeyFunc)
	if err != nil || !token.Valid {
		return nil
	}

	username, ok := a.tokenUsername(token)
	if !ok {
		return nil
	}

	user, err := a.db.GetUser(username)
	if err != nil {
		log.WithError(err).Error("error loading user object")
		return nil
	}
	return user
}

// isAuthorized wraps endpoint so it only runs for requests that carry a valid
// token in the "Token" header. On success, the parsed token and the loaded
// user are stored in the request context under TokenContextKey and
// UserContextKey.
//
// On failure it responds with:
//   - 401 if the header is missing, or the token fails verification (for
//     example a bad signature, a non-HMAC signing method, or a missing
//     username claim);
//   - 400 if the token is malformed;
//   - 500 if the user cannot be loaded.
func (a *API) isAuthorized(endpoint httprouter.Handle) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
		raw := r.Header.Get("Token")
		if raw == "" {
			http.Error(w, "No Token Provided", http.StatusUnauthorized)
			return
		}

		token, err := jwt.Parse(raw, a.jwtKeyFunc)
		if err != nil {
			log.WithError(err).Error("error parsing token")
			// jwt.Parse reports every verification failure as an error, not
			// only malformed input. So only a syntactically broken token is a
			// bad request; any other failure means the caller is unauthorized.
			var verr *jwt.ValidationError
			if errors.As(err, &verr) && verr.Errors&jwt.ValidationErrorMalformed != 0 {
				http.Error(w, "Bad Request", http.StatusBadRequest)
			} else {
				http.Error(w, "Invalid Token", http.StatusUnauthorized)
			}
			return
		}
		if !token.Valid {
			http.Error(w, "Invalid Token", http.StatusUnauthorized)
			return
		}

		username, ok := a.tokenUsername(token)
		if !ok {
			http.Error(w, "Invalid Token", http.StatusUnauthorized)
			return
		}

		user, err := a.db.GetUser(username)
		if err != nil {
			log.WithError(err).Error("error loading user object")
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		ctx := context.WithValue(r.Context(), TokenContextKey, token)
		ctx = context.WithValue(ctx, UserContextKey, user)
		endpoint(w, r.WithContext(ctx), p)
	}
}

// PingEndpoint returns a handler that replies with an empty JSON object. It
// can serve as a liveness check for the API.
func (a *API) PingEndpoint() httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
		w.Header().Set("Content-Type", "application/json")
		// The write error is ignored: there is nothing useful to do once the
		// client connection has failed.
		_, _ = w.Write([]byte(`{}`))
	}
}