Initial Codebase (untested)
This commit is contained in:
166
internal/session/manager.go
Normal file
166
internal/session/manager.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/andreadipersio/securecookie"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Key ...
|
||||
type Key int
|
||||
|
||||
const (
|
||||
SessionKey Key = iota
|
||||
)
|
||||
|
||||
// Options ...
|
||||
type Options struct {
|
||||
name string
|
||||
secret string
|
||||
secure bool
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
// NewOptions ...
|
||||
func NewOptions(name, secret string, secure bool, expiry time.Duration) *Options {
|
||||
return &Options{name, secret, secure, expiry}
|
||||
}
|
||||
|
||||
// Manager ...
|
||||
type Manager struct {
|
||||
options *Options
|
||||
store Store
|
||||
}
|
||||
|
||||
// NewManager ...
|
||||
func NewManager(options *Options, store Store) *Manager {
|
||||
return &Manager{options, store}
|
||||
}
|
||||
|
||||
// Create ...
|
||||
func (m *Manager) Create(w http.ResponseWriter) (*Session, error) {
|
||||
sid, err := NewSessionID(m.options.secret)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error creating new session")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: m.options.name,
|
||||
Value: sid.String(),
|
||||
Path: "/",
|
||||
Secure: m.options.secure,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(m.options.expiry.Seconds()),
|
||||
Expires: time.Now().Add(m.options.expiry),
|
||||
}
|
||||
|
||||
securecookie.SetSecureCookie(w, m.options.secret, cookie)
|
||||
|
||||
return &Session{
|
||||
store: m.store,
|
||||
|
||||
ID: sid.String(),
|
||||
Data: make(Map),
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(m.options.expiry),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate ....
|
||||
func (m *Manager) Validate(value string) (ID, error) {
|
||||
sessionID, err := ValidateSessionID(value, m.options.secret)
|
||||
return sessionID, err
|
||||
}
|
||||
|
||||
// GetOrCreate ...
|
||||
func (m *Manager) GetOrCreate(w http.ResponseWriter, r *http.Request) (*Session, error) {
|
||||
cookie, err := securecookie.GetSecureCookie(
|
||||
r,
|
||||
m.options.secret,
|
||||
m.options.name,
|
||||
)
|
||||
if err != nil {
|
||||
sess, err := m.Create(w)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error creating new session")
|
||||
return nil, err
|
||||
}
|
||||
if err = m.store.SetSession(sess.ID, sess); err != nil {
|
||||
log.WithError(err).Errorf("error creating new session for %s", sess.ID)
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
sid, err := m.Validate(cookie.Value)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error validating seesion")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sess, err := m.store.GetSession(sid.String())
|
||||
if err != nil {
|
||||
if err == ErrSessionNotFound {
|
||||
log.WithError(err).Warnf("no session found for %s (creating new one)", sid)
|
||||
m.Delete(w, r)
|
||||
|
||||
sess, err := m.Create(w)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error creating new session")
|
||||
return nil, err
|
||||
}
|
||||
if err = m.store.SetSession(sess.ID, sess); err != nil {
|
||||
log.WithError(err).Errorf("error creating new session for %s", sess.ID)
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
log.WithError(err).Errorf("error loading session for %s", sid)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Delete ...
|
||||
func (m *Manager) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
if sess := r.Context().Value(SessionKey); sess != nil {
|
||||
sess := sess.(*Session)
|
||||
if err := m.store.DelSession(sess.ID); err != nil {
|
||||
log.WithError(err).Warnf("error deleting session %s", sess.ID)
|
||||
}
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: m.options.name,
|
||||
Value: "",
|
||||
Secure: m.options.secure,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
MaxAge: -1,
|
||||
Expires: time.Now(),
|
||||
}
|
||||
|
||||
securecookie.SetSecureCookie(w, m.options.secret, cookie)
|
||||
}
|
||||
|
||||
// Handler ...
|
||||
func (m *Manager) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := m.GetOrCreate(w, r)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("session error")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), SessionKey, sess)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
67
internal/session/memorystore.go
Normal file
67
internal/session/memorystore.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// MemoryStore represents an in-memory session store.
|
||||
// This should be used only for testing and prototyping.
|
||||
// Production systems should use a shared server store like redis
|
||||
type MemoryStore struct {
|
||||
entries *cache.Cache
|
||||
}
|
||||
|
||||
// NewMemoryStore constructs and returns a new MemoryStore
|
||||
func NewMemoryStore(sessionDuration time.Duration) *MemoryStore {
|
||||
if sessionDuration < 0 {
|
||||
sessionDuration = DefaultSessionDuration
|
||||
}
|
||||
return &MemoryStore{
|
||||
entries: cache.New(sessionDuration, time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession ...
|
||||
func (s *MemoryStore) GetSession(sid string) (*Session, error) {
|
||||
val, found := s.entries.Get(sid)
|
||||
if !found {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
sess := val.(*Session)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// SetSession ...
|
||||
func (s *MemoryStore) SetSession(sid string, sess *Session) error {
|
||||
s.entries.Set(sid, sess, cache.DefaultExpiration)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasSession ...
|
||||
func (s *MemoryStore) HasSession(sid string) bool {
|
||||
_, ok := s.entries.Get(sid)
|
||||
return ok
|
||||
}
|
||||
|
||||
// DelSession ...
|
||||
func (s *MemoryStore) DelSession(sid string) error {
|
||||
s.entries.Delete(sid)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncSession ...
|
||||
func (s *MemoryStore) SyncSession(sess *Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllSessions ...
|
||||
func (s *MemoryStore) GetAllSessions() ([]*Session, error) {
|
||||
var sessions []*Session
|
||||
for _, item := range s.entries.Items() {
|
||||
sess := item.Object.(*Session)
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
67
internal/session/session.go
Normal file
67
internal/session/session.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Map ...
|
||||
type Map map[string]string
|
||||
|
||||
// Session ...
|
||||
type Session struct {
|
||||
store Store
|
||||
|
||||
ID string `json:"id"`
|
||||
Data Map `json:"data"`
|
||||
CreatedAt time.Time `json:"created"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
}
|
||||
|
||||
func NewSession(store Store) *Session {
|
||||
return &Session{store: store}
|
||||
}
|
||||
|
||||
func LoadSession(data []byte, sess *Session) error {
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sess.Data == nil {
|
||||
sess.Data = make(Map)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sess *Session) Expired() bool {
|
||||
return sess.ExpiresAt.Before(time.Now())
|
||||
}
|
||||
|
||||
func (sess *Session) Set(key, val string) error {
|
||||
sess.Data[key] = val
|
||||
return sess.store.SyncSession(sess)
|
||||
}
|
||||
|
||||
func (sess *Session) Get(key string) (val string, ok bool) {
|
||||
val, ok = sess.Data[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (sess *Session) Has(key string) bool {
|
||||
_, ok := sess.Data[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (sess *Session) Del(key string) error {
|
||||
delete(sess.Data, key)
|
||||
return sess.store.SyncSession(sess)
|
||||
}
|
||||
|
||||
func (sess *Session) Bytes() ([]byte, error) {
|
||||
data, err := json.Marshal(sess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
65
internal/session/sid.go
Normal file
65
internal/session/sid.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// InvalidSessionID represents an empty, invalid session ID
|
||||
const InvalidSessionID ID = ""
|
||||
|
||||
const idLength = 32
|
||||
const signedLength = idLength + sha256.Size
|
||||
|
||||
// ID represents a valid, digitally-signed session ID
|
||||
type ID string
|
||||
|
||||
// ErrInvalidID is returned when an invalid session id is passed to ValidateID()
|
||||
var ErrInvalidID = errors.New("Invalid Session ID")
|
||||
|
||||
// NewSessionID creates and returns a new digitally-signed session ID,
|
||||
// using `signingKey` as the HMAC signing key. An error is returned only
|
||||
// if there was an error generating random bytes for the session ID
|
||||
func NewSessionID(signingKey string) (ID, error) {
|
||||
buf := make([]byte, signedLength)
|
||||
_, err := rand.Read(buf[:idLength])
|
||||
if err != nil {
|
||||
return InvalidSessionID, err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
_, _ = mac.Write(buf[:idLength])
|
||||
sig := mac.Sum(nil)
|
||||
copy(buf[idLength:], sig)
|
||||
|
||||
return ID(base64.URLEncoding.EncodeToString(buf)), nil
|
||||
}
|
||||
|
||||
// ValidateSessionID validates the `id` parameter using the `signingKey`
|
||||
// and returns an error if invalid, or a SignedID if valid
|
||||
func ValidateSessionID(id string, signingKey string) (ID, error) {
|
||||
buf, err := base64.URLEncoding.DecodeString(id)
|
||||
if err != nil {
|
||||
return InvalidSessionID, err
|
||||
}
|
||||
|
||||
if len(buf) < signedLength {
|
||||
return InvalidSessionID, ErrInvalidID
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
_, _ = mac.Write(buf[:idLength])
|
||||
messageMAC := mac.Sum(nil)
|
||||
if !hmac.Equal(messageMAC, buf[idLength:]) {
|
||||
return InvalidSessionID, ErrInvalidID
|
||||
}
|
||||
|
||||
return ID(id), nil
|
||||
}
|
||||
|
||||
func (sid ID) String() string {
|
||||
return string(sid)
|
||||
}
|
||||
75
internal/session/sid_test.go
Normal file
75
internal/session/sid_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testSigningKey = "a very secret key"
|
||||
|
||||
func TestNewID(t *testing.T) {
|
||||
sid, err := NewSessionID(testSigningKey)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if 0 == len(sid) {
|
||||
t.Errorf("Signed ID string was empty")
|
||||
}
|
||||
|
||||
_, err = ValidateSessionID(sid.String(), testSigningKey)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInvalidKey(t *testing.T) {
|
||||
sid, err := NewSessionID(testSigningKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = ValidateSessionID(sid.String(), "some other signing key")
|
||||
if nil == err {
|
||||
t.Errorf("Was able to validate with incorrect signign key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModified(t *testing.T) {
|
||||
sid, err := NewSessionID(testSigningKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
runes := []rune(sid.String())
|
||||
runes[0]++
|
||||
modsid := string(runes)
|
||||
|
||||
_, err = ValidateSessionID(modsid, testSigningKey)
|
||||
if nil == err {
|
||||
t.Errorf("Was able to validate modified encoded string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyID(t *testing.T) {
|
||||
_, err := ValidateSessionID("", testSigningKey)
|
||||
if err == nil {
|
||||
t.Error("Able to validate empty key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadKey(t *testing.T) {
|
||||
buf := make([]byte, signedLength)
|
||||
if _, err := rand.Read(buf); nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
badid := base64.URLEncoding.EncodeToString(buf)
|
||||
|
||||
_, err := ValidateSessionID(badid, testSigningKey)
|
||||
if err == nil {
|
||||
t.Error("Able to validate bad key")
|
||||
}
|
||||
}
|
||||
32
internal/session/store.go
Normal file
32
internal/session/store.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultSessionDuration is the default duration for
|
||||
// saving session data in the store. Most Store implementations
|
||||
// will automatically delete saved session data after this time.
|
||||
const DefaultSessionDuration = time.Hour
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("sessin not found or expired")
|
||||
ErrSessionExpired = errors.New("session expired")
|
||||
)
|
||||
|
||||
// Store represents a session data store.
|
||||
// This is an abstract interface that can be implemented
|
||||
// against several different types of data stores. For example,
|
||||
// session data could be stored in memory in a concurrent map,
|
||||
// or more typically in a shared key/value server store like redis.
|
||||
type Store interface {
|
||||
GetSession(sid string) (*Session, error)
|
||||
SetSession(sid string, sess *Session) error
|
||||
HasSession(sid string) bool
|
||||
DelSession(sid string) error
|
||||
|
||||
SyncSession(sess *Session) error
|
||||
|
||||
GetAllSessions() ([]*Session, error)
|
||||
}
|
||||
Reference in New Issue
Block a user