Initial Codebase (untested)

This commit is contained in:
James Mills
2021-01-30 14:05:04 +10:00
parent c1dc91b7e0
commit 4529ea3196
60 changed files with 9807 additions and 0 deletions

166
internal/session/manager.go Normal file
View File

@@ -0,0 +1,166 @@
package session
import (
"context"
"net/http"
"time"
"github.com/andreadipersio/securecookie"
log "github.com/sirupsen/logrus"
)
// Key ...
type Key int
const (
SessionKey Key = iota
)
// Options ...
type Options struct {
name string
secret string
secure bool
expiry time.Duration
}
// NewOptions ...
func NewOptions(name, secret string, secure bool, expiry time.Duration) *Options {
return &Options{name, secret, secure, expiry}
}
// Manager ...
type Manager struct {
options *Options
store Store
}
// NewManager ...
func NewManager(options *Options, store Store) *Manager {
return &Manager{options, store}
}
// Create ...
func (m *Manager) Create(w http.ResponseWriter) (*Session, error) {
sid, err := NewSessionID(m.options.secret)
if err != nil {
log.WithError(err).Error("error creating new session")
return nil, err
}
cookie := &http.Cookie{
Name: m.options.name,
Value: sid.String(),
Path: "/",
Secure: m.options.secure,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(m.options.expiry.Seconds()),
Expires: time.Now().Add(m.options.expiry),
}
securecookie.SetSecureCookie(w, m.options.secret, cookie)
return &Session{
store: m.store,
ID: sid.String(),
Data: make(Map),
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(m.options.expiry),
}, nil
}
// Validate ....
func (m *Manager) Validate(value string) (ID, error) {
sessionID, err := ValidateSessionID(value, m.options.secret)
return sessionID, err
}
// GetOrCreate ...
func (m *Manager) GetOrCreate(w http.ResponseWriter, r *http.Request) (*Session, error) {
cookie, err := securecookie.GetSecureCookie(
r,
m.options.secret,
m.options.name,
)
if err != nil {
sess, err := m.Create(w)
if err != nil {
log.WithError(err).Error("error creating new session")
return nil, err
}
if err = m.store.SetSession(sess.ID, sess); err != nil {
log.WithError(err).Errorf("error creating new session for %s", sess.ID)
return nil, err
}
return sess, nil
}
sid, err := m.Validate(cookie.Value)
if err != nil {
log.WithError(err).Error("error validating seesion")
return nil, err
}
sess, err := m.store.GetSession(sid.String())
if err != nil {
if err == ErrSessionNotFound {
log.WithError(err).Warnf("no session found for %s (creating new one)", sid)
m.Delete(w, r)
sess, err := m.Create(w)
if err != nil {
log.WithError(err).Error("error creating new session")
return nil, err
}
if err = m.store.SetSession(sess.ID, sess); err != nil {
log.WithError(err).Errorf("error creating new session for %s", sess.ID)
return nil, err
}
return sess, nil
}
log.WithError(err).Errorf("error loading session for %s", sid)
return nil, err
}
return sess, nil
}
// Delete ...
func (m *Manager) Delete(w http.ResponseWriter, r *http.Request) {
if sess := r.Context().Value(SessionKey); sess != nil {
sess := sess.(*Session)
if err := m.store.DelSession(sess.ID); err != nil {
log.WithError(err).Warnf("error deleting session %s", sess.ID)
}
}
cookie := &http.Cookie{
Name: m.options.name,
Value: "",
Secure: m.options.secure,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
MaxAge: -1,
Expires: time.Now(),
}
securecookie.SetSecureCookie(w, m.options.secret, cookie)
}
// Handler ...
func (m *Manager) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sess, err := m.GetOrCreate(w, r)
if err != nil {
log.WithError(err).Error("session error")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
ctx := context.WithValue(r.Context(), SessionKey, sess)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -0,0 +1,67 @@
package session
import (
"time"
"github.com/patrickmn/go-cache"
)
// MemoryStore represents an in-memory session store.
// This should be used only for testing and prototyping.
// Production systems should use a shared server store like redis
type MemoryStore struct {
entries *cache.Cache
}
// NewMemoryStore constructs and returns a new MemoryStore
func NewMemoryStore(sessionDuration time.Duration) *MemoryStore {
if sessionDuration < 0 {
sessionDuration = DefaultSessionDuration
}
return &MemoryStore{
entries: cache.New(sessionDuration, time.Minute),
}
}
// GetSession ...
func (s *MemoryStore) GetSession(sid string) (*Session, error) {
val, found := s.entries.Get(sid)
if !found {
return nil, ErrSessionNotFound
}
sess := val.(*Session)
return sess, nil
}
// SetSession ...
func (s *MemoryStore) SetSession(sid string, sess *Session) error {
s.entries.Set(sid, sess, cache.DefaultExpiration)
return nil
}
// HasSession ...
func (s *MemoryStore) HasSession(sid string) bool {
_, ok := s.entries.Get(sid)
return ok
}
// DelSession ...
func (s *MemoryStore) DelSession(sid string) error {
s.entries.Delete(sid)
return nil
}
// SyncSession ...
func (s *MemoryStore) SyncSession(sess *Session) error {
return nil
}
// GetAllSessions ...
func (s *MemoryStore) GetAllSessions() ([]*Session, error) {
var sessions []*Session
for _, item := range s.entries.Items() {
sess := item.Object.(*Session)
sessions = append(sessions, sess)
}
return sessions, nil
}

View File

@@ -0,0 +1,67 @@
package session
import (
"encoding/json"
"time"
)
// Map ...
type Map map[string]string
// Session ...
type Session struct {
store Store
ID string `json:"id"`
Data Map `json:"data"`
CreatedAt time.Time `json:"created"`
ExpiresAt time.Time `json:"expires"`
}
func NewSession(store Store) *Session {
return &Session{store: store}
}
func LoadSession(data []byte, sess *Session) error {
if err := json.Unmarshal(data, &sess); err != nil {
return err
}
if sess.Data == nil {
sess.Data = make(Map)
}
return nil
}
func (sess *Session) Expired() bool {
return sess.ExpiresAt.Before(time.Now())
}
func (sess *Session) Set(key, val string) error {
sess.Data[key] = val
return sess.store.SyncSession(sess)
}
func (sess *Session) Get(key string) (val string, ok bool) {
val, ok = sess.Data[key]
return
}
func (sess *Session) Has(key string) bool {
_, ok := sess.Data[key]
return ok
}
func (sess *Session) Del(key string) error {
delete(sess.Data, key)
return sess.store.SyncSession(sess)
}
func (sess *Session) Bytes() ([]byte, error) {
data, err := json.Marshal(sess)
if err != nil {
return nil, err
}
return data, nil
}

65
internal/session/sid.go Normal file
View File

@@ -0,0 +1,65 @@
package session
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
)
// InvalidSessionID represents an empty, invalid session ID
const InvalidSessionID ID = ""
const idLength = 32
const signedLength = idLength + sha256.Size
// ID represents a valid, digitally-signed session ID
type ID string
// ErrInvalidID is returned when an invalid session id is passed to ValidateID()
var ErrInvalidID = errors.New("Invalid Session ID")
// NewSessionID creates and returns a new digitally-signed session ID,
// using `signingKey` as the HMAC signing key. An error is returned only
// if there was an error generating random bytes for the session ID
func NewSessionID(signingKey string) (ID, error) {
buf := make([]byte, signedLength)
_, err := rand.Read(buf[:idLength])
if err != nil {
return InvalidSessionID, err
}
mac := hmac.New(sha256.New, []byte(signingKey))
_, _ = mac.Write(buf[:idLength])
sig := mac.Sum(nil)
copy(buf[idLength:], sig)
return ID(base64.URLEncoding.EncodeToString(buf)), nil
}
// ValidateSessionID validates the `id` parameter using the `signingKey`
// and returns an error if invalid, or a SignedID if valid
func ValidateSessionID(id string, signingKey string) (ID, error) {
buf, err := base64.URLEncoding.DecodeString(id)
if err != nil {
return InvalidSessionID, err
}
if len(buf) < signedLength {
return InvalidSessionID, ErrInvalidID
}
mac := hmac.New(sha256.New, []byte(signingKey))
_, _ = mac.Write(buf[:idLength])
messageMAC := mac.Sum(nil)
if !hmac.Equal(messageMAC, buf[idLength:]) {
return InvalidSessionID, ErrInvalidID
}
return ID(id), nil
}
func (sid ID) String() string {
return string(sid)
}

View File

@@ -0,0 +1,75 @@
package session
import (
"crypto/rand"
"encoding/base64"
"testing"
)
const testSigningKey = "a very secret key"
func TestNewID(t *testing.T) {
sid, err := NewSessionID(testSigningKey)
if err != nil {
t.Fatal(err)
}
if 0 == len(sid) {
t.Errorf("Signed ID string was empty")
}
_, err = ValidateSessionID(sid.String(), testSigningKey)
if nil != err {
t.Fatal(err)
}
}
func TestInvalidKey(t *testing.T) {
sid, err := NewSessionID(testSigningKey)
if err != nil {
t.Fatal(err)
}
_, err = ValidateSessionID(sid.String(), "some other signing key")
if nil == err {
t.Errorf("Was able to validate with incorrect signign key")
}
}
func TestModified(t *testing.T) {
sid, err := NewSessionID(testSigningKey)
if err != nil {
t.Fatal(err)
}
runes := []rune(sid.String())
runes[0]++
modsid := string(runes)
_, err = ValidateSessionID(modsid, testSigningKey)
if nil == err {
t.Errorf("Was able to validate modified encoded string")
}
}
func TestEmptyID(t *testing.T) {
_, err := ValidateSessionID("", testSigningKey)
if err == nil {
t.Error("Able to validate empty key")
}
}
func TestBadKey(t *testing.T) {
buf := make([]byte, signedLength)
if _, err := rand.Read(buf); nil != err {
t.Fatal(err)
}
badid := base64.URLEncoding.EncodeToString(buf)
_, err := ValidateSessionID(badid, testSigningKey)
if err == nil {
t.Error("Able to validate bad key")
}
}

32
internal/session/store.go Normal file
View File

@@ -0,0 +1,32 @@
package session
import (
"errors"
"time"
)
// DefaultSessionDuration is the default duration for
// saving session data in the store. Most Store implementations
// will automatically delete saved session data after this time.
const DefaultSessionDuration = time.Hour
var (
ErrSessionNotFound = errors.New("sessin not found or expired")
ErrSessionExpired = errors.New("session expired")
)
// Store represents a session data store.
// This is an abstract interface that can be implemented
// against several different types of data stores. For example,
// session data could be stored in memory in a concurrent map,
// or more typically in a shared key/value server store like redis.
type Store interface {
GetSession(sid string) (*Session, error)
SetSession(sid string, sess *Session) error
HasSession(sid string) bool
DelSession(sid string) error
SyncSession(sess *Session) error
GetAllSessions() ([]*Session, error)
}