diff --git a/go.mod b/go.mod index 845dcd4..84dad01 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/sirupsen/logrus v1.7.0 github.com/spf13/pflag v1.0.5 github.com/steambap/captcha v1.3.1 + github.com/stretchr/testify v1.6.1 github.com/temoto/robotstxt v1.1.1 // indirect github.com/unrolled/logger v0.0.0-20201216141554-31a3694fe979 github.com/vcraescu/go-paginator v1.0.0 diff --git a/go.sum b/go.sum index 3e47cc3..d6ed2e7 100644 --- a/go.sum +++ b/go.sum @@ -133,6 +133,7 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0 github.com/daaku/go.zipexe v1.0.0 h1:VSOgZtH418pH9L16hC/JrgSNJbbAL26pj7lmD1+CGdY= github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CLnGVd57E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= @@ -349,6 +350,7 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/plar/go-adaptive-radix-tree v1.0.4 h1:Ucd8R6RH2E7RW8ZtDKrsWyOD3paG2qqJO0I20WQ8oWQ= github.com/plar/go-adaptive-radix-tree v1.0.4/go.mod h1:Ot8d28EII3i7Lv4PSvBlF8ejiD/CtRYDuPsySJbSaK8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prologic/bitcask v0.3.10 h1:HXygU8zCvW5gLpZ8aQECPk5iV/YQ3hcqdg/zVeES6s0= @@ -432,6 +434,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= @@ -667,6 +670,7 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.1.3/go.mod h1:AKDgRWk8lcSQSw+9kxCJnX/yySj8G3rdwYlU57cB45c= gorm.io/gorm v1.20.1/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= diff --git a/internal/base_task.go b/internal/base_task.go new file mode 100644 index 0000000..301a01f --- /dev/null +++ b/internal/base_task.go @@ -0,0 +1,64 @@ +package internal + +import ( + "fmt" + + "github.com/renstrom/shortuuid" +) + +type BaseTask struct { + state TaskState + data TaskData + err error + id string +} + +func NewBaseTask() *BaseTask { + return &BaseTask{ + data: make(TaskData), + id: shortuuid.New(), + } +} + +func (t *BaseTask) SetState(state TaskState) { + t.state = state +} + +func (t *BaseTask) SetData(key, val string) { + if t.data == nil { + t.data = make(TaskData) + } + t.data[key] = val +} + +func (t *BaseTask) Done() { + if t.err != nil { + t.state = TaskStateFailed + } else { + t.state = TaskStateComplete + } +} + +func (t *BaseTask) Fail(err error) error { + t.err = err + return err +} + +func (t *BaseTask) Result() TaskResult { + stateStr := t.state.String() + errStr := "" + if t.err != nil { + errStr = t.err.Error() + } + + return TaskResult{ + State: stateStr, + Error: errStr, + Data: t.data, + } +} + +func (t *BaseTask) String() string { return fmt.Sprintf("%T: %s", t, t.ID()) } +func (t *BaseTask) ID() string { return t.id } +func (t *BaseTask) State() TaskState { return t.state } +func (t *BaseTask) Error() error { return t.err } diff --git a/internal/crawl_task.go b/internal/crawl_task.go new file mode 100644 index 0000000..732f86e --- /dev/null +++ b/internal/crawl_task.go @@ -0,0 +1,81 @@ +package internal + +import ( + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +type CrawlTask struct { + *BaseTask + + conf *Config + db Store + indexer Indexer + + url string +} + +func NewCrawlTask(conf *Config, db Store, indexer Indexer, url string) *CrawlTask { + return &CrawlTask{ + BaseTask: NewBaseTask(), + + conf: conf, + db: db, + indexer: indexer, + + url: url, + } +} + +func (t *CrawlTask) String() string { return fmt.Sprintf("%T: %s", t, t.ID()) } +func (t *CrawlTask) Run() error { + defer t.Done() + t.SetState(TaskStateRunning) + + log.Infof("starting crawl task for %s", t.url) + + log.Debugf("crawling %s", t.url) + + links, err := GetLinks(t.url) + if err != nil { + log.WithError(err).Error("error crawling %s", t.url) + return t.Fail(fmt.Errorf("error crawling %s: %w", t.url, err)) + } + + for link := range links { + hash := HashURL(link) + + if t.db.HasURL(hash) { + log.Debugf("seen %s (skipping)", link) + return nil + } + + log.Debugf("found %s", link) + + metrics.Counter("crawler", "crawled").Inc() + + url := NewURL(link) + url.CrawledAt = time.Now() + + entry, err := Scrape(t.conf, link) + if err != nil { + log.WithError(err).Warn("error scraping %s", link) + continue + } + + if err := t.indexer.Index(entry); err != nil { + log.WithError(err).Warn("error indexing %s", link) + continue + } + + if err := t.db.SetURL(hash, url); err != nil { + log.WithError(err).Warn("error recording url %s", link) + } + + metrics.Counter("crawler", "scraped").Inc() + } + + return nil +} diff --git a/internal/crawler.go b/internal/crawler.go index ce8fc2c..0779364 100644 --- a/internal/crawler.go +++ b/internal/crawler.go @@ -1,71 +1,48 @@ package internal import ( - "time" - log "github.com/sirupsen/logrus" ) type Crawler interface { Start() + Stop() Crawl(url string) error } type crawler struct { conf *Config + tasks *Dispatcher db Store - queue chan string indexer Indexer + queue chan string } -func NewCrawler(conf *Config, db Store, indexer Indexer) (Crawler, error) { +func NewCrawler(conf *Config, tasks *Dispatcher, db Store, indexer Indexer) (Crawler, error) { return &crawler{ conf: conf, + tasks: tasks, db: db, - queue: make(chan string), indexer: indexer, + queue: make(chan string), }, nil } func (c *crawler) loop() { for { - url := <-c.queue - log.Debugf("crawling %s", url) - - links, err := GetLinks(url) - if err != nil { - log.WithError(err).Error("error crawling %s", url) - continue - } - - for link := range links { - hash := HashURL(link) - - if c.db.HasURL(hash) { - log.Debugf("seen %s (skipping)", link) - continue + select { + case url, ok := <-c.queue: + if !ok { + log.Debugf("crawler shutting down...") + return } - - log.Debugf("found %s", link) - - metrics.Counter("crawler", "crawled").Inc() - - url := NewURL(link) - url.CrawledAt = time.Now() - - entry, err := Scrape(c.conf, link) + log.Debugf("crawling %s", url) + uuid, err := c.tasks.Dispatch(NewCrawlTask(c.conf, c.db, c.indexer, url)) if err != nil { - log.WithError(err).Error("error scraping %s", link) + log.WithError(err).Error("error creating crawl task for %s", url) } else { - if err := c.indexer.Index(entry); err != nil { - log.WithError(err).Error("error indexing %s", link) - } else { - if err := c.db.SetURL(hash, url); err != nil { - log.WithError(err).Error("error recording url %s", link) - } else { - metrics.Counter("crawler", "scraped").Inc() - } - } + taskURL := URLForTask(c.conf.BaseURL, uuid) + log.WithField("uuid", uuid).Infof("successfully created crawl task for %s: %s", url, taskURL) } } } @@ -76,6 +53,10 @@ func (c *crawler) Crawl(url string) error { return nil } +func (c *crawler) Stop() { + close(c.queue) +} + func (c *crawler) Start() { go c.loop() } diff --git a/internal/dispatcher.go b/internal/dispatcher.go new file mode 100644 index 0000000..51c6f98 --- /dev/null +++ b/internal/dispatcher.go @@ -0,0 +1,102 @@ +package internal + +import ( + "errors" +) + +// Dispatcher maintains a pool for available workers +// and a task queue that workers will process +type Dispatcher struct { + maxWorkers int + maxQueue int + workers []*Worker + workerPool chan chan Task + taskQueue chan Task + taskMap map[string]Task + quit chan bool + active bool +} + +// NewDispatcher creates a new dispatcher with the given +// number of workers and buffers the task queue based on maxQueue. +// It also initializes the channels for the worker pool and task queue +func NewDispatcher(maxWorkers int, maxQueue int) *Dispatcher { + return &Dispatcher{ + maxWorkers: maxWorkers, + maxQueue: maxQueue, + } +} + +// Start creates and starts workers, adding them to the worker pool. +// Then, it starts a select loop to wait for tasks to be dispatched +// to available workers +func (d *Dispatcher) Start() { + d.workers = []*Worker{} + d.workerPool = make(chan chan Task, d.maxWorkers) + d.taskQueue = make(chan Task, d.maxQueue) + d.taskMap = make(map[string]Task) + d.quit = make(chan bool) + + for i := 0; i < d.maxWorkers; i++ { + worker := NewWorker(d.workerPool) + worker.Start() + d.workers = append(d.workers, worker) + } + + d.active = true + + go func() { + for { + select { + case task := <-d.taskQueue: + go func(task Task) { + taskChannel := <-d.workerPool + taskChannel <- task + }(task) + case <-d.quit: + return + } + } + }() +} + +// Stop ends execution for all workers and closes all channels, then removes +// all workers +func (d *Dispatcher) Stop() { + if !d.active { + return + } + + d.active = false + + for i := range d.workers { + d.workers[i].Stop() + } + + d.workers = []*Worker{} + d.quit <- true +} + +// Lookup returns the matching `Task` given its id +func (d *Dispatcher) Lookup(id string) (Task, bool) { + task, ok := d.taskMap[id] + return task, ok +} + +// Dispatch pushes the given task into the task queue. +// The first available worker will perform the task +func (d *Dispatcher) Dispatch(task Task) (string, error) { + if !d.active { + return "", errors.New("dispatcher is not active") + } + + d.taskQueue <- task + d.taskMap[task.ID()] = task + return task.ID(), nil +} + +// DispatchFunc pushes the given func into the task queue by first wrapping +// it with a `TaskFunc` task. +func (d *Dispatcher) DispatchFunc(f func() error) (string, error) { + return d.Dispatch(NewFuncTask(f)) +} diff --git a/internal/dispatcher_test.go b/internal/dispatcher_test.go new file mode 100644 index 0000000..aa6f88f --- /dev/null +++ b/internal/dispatcher_test.go @@ -0,0 +1,110 @@ +package internal + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDispatcher_Dispatch(t *testing.T) { + a := 0 + aMu := sync.RWMutex{} + + b := 0 + bMu := sync.RWMutex{} + + c := 0 + cMu := sync.RWMutex{} + + d := NewDispatcher(10, 3) + d.Start() + + _, _ = d.DispatchFunc(func() error { + aMu.Lock() + a = 1 + aMu.Unlock() + return nil + }) + + _, _ = d.DispatchFunc(func() error { + bMu.Lock() + b = 2 + bMu.Unlock() + return nil + }) + + _, _ = d.DispatchFunc(func() error { + cMu.Lock() + c = 3 + cMu.Unlock() + return nil + }) + + time.Sleep(time.Second) + + aMu.RLock() + assert.Equal(t, 1, a) + aMu.RUnlock() + + bMu.RLock() + assert.Equal(t, 2, b) + bMu.RUnlock() + + cMu.RLock() + assert.Equal(t, 3, c) + cMu.RUnlock() +} + +func TestDispatcher_Dispatch_Mutex(t *testing.T) { + n := 100 + mu := &sync.RWMutex{} + + d := NewDispatcher(10, n) + d.Start() + + var v []int + + for i := 0; i < n; i++ { + _, _ = d.DispatchFunc(func() error { + mu.Lock() + v = append(v, 0) + mu.Unlock() + return nil + }) + } + + time.Sleep(time.Second) + + mu.RLock() + assert.Equal(t, n, len(v)) + mu.RUnlock() +} + +func TestDispatcher_Stop(t *testing.T) { + c := 0 + mu := sync.RWMutex{} + + d := NewDispatcher(1, 3) + d.Start() + + _, _ = d.DispatchFunc(func() error { + mu.Lock() + c++ + mu.Unlock() + return nil + }) + + time.Sleep(time.Millisecond * 100) + d.Stop() + time.Sleep(time.Millisecond * 100) + + _, err := d.DispatchFunc(func() error { + mu.Lock() + c++ + mu.Unlock() + return nil + }) + assert.NotNil(t, err) +} diff --git a/internal/func_task.go b/internal/func_task.go new file mode 100644 index 0000000..910b401 --- /dev/null +++ b/internal/func_task.go @@ -0,0 +1,25 @@ +package internal + +import "fmt" + +type FuncTask struct { + *BaseTask + + f func() error +} + +func NewFuncTask(f func() error) *FuncTask { + return &FuncTask{ + BaseTask: NewBaseTask(), + + f: f, + } +} + +func (t *FuncTask) String() string { return fmt.Sprintf("%T: %s", t, t.ID()) } +func (t *FuncTask) Run() error { + defer t.Done() + t.SetState(TaskStateRunning) + + return t.f() +} diff --git a/internal/server.go b/internal/server.go index 894e0ae..5fc699d 100644 --- a/internal/server.go +++ b/internal/server.go @@ -53,6 +53,9 @@ type Server struct { // Scheduler cron *cron.Cron + // Dispatcher + tasks *Dispatcher + // Auth am *auth.Manager @@ -93,6 +96,8 @@ func (s *Server) AddShutdownHook(f func()) { // Shutdown ... func (s *Server) Shutdown(ctx context.Context) error { s.cron.Stop() + s.tasks.Stop() + s.crawler.Stop() if err := s.server.Shutdown(ctx); err != nil { log.WithError(err).Error("error shutting down server") @@ -300,6 +305,9 @@ func (s *Server) initRoutes() { s.router.GET("/chpasswd", s.ResetPasswordMagicLinkHandler()) s.router.POST("/chpasswd", s.NewPasswordHandler()) + // Task State + s.router.GET("/task/:uuid", s.TaskHandler()) + s.router.GET("/add", s.AddHandler()) s.router.POST("/add", s.AddHandler()) @@ -362,22 +370,12 @@ func NewServer(bind string, options ...Option) (*Server, error) { return nil, err } - indexer, err := NewIndexer(config) - if err != nil { - log.WithError(err).Error("error creating indexer") - return nil, err - } - - crawler, err := NewCrawler(config, db, indexer) - if err != nil { - log.WithError(err).Error("error creating crawler") - return nil, err - } - router := NewRouter() am := auth.NewManager(auth.NewOptions("/login", "/register")) + tasks := NewDispatcher(10, 100) // TODO: Make this configurable? + pm := passwords.NewScryptPasswords(nil) sc := NewSessionStore(db, config.SessionCacheTTL) @@ -392,6 +390,18 @@ func NewServer(bind string, options ...Option) (*Server, error) { sc, ) + indexer, err := NewIndexer(config) + if err != nil { + log.WithError(err).Error("error creating indexer") + return nil, err + } + + crawler, err := NewCrawler(config, tasks, db, indexer) + if err != nil { + log.WithError(err).Error("error creating crawler") + return nil, err + } + api := NewAPI(router, config, db, pm) csrfHandler := nosurf.New(router) @@ -430,6 +440,9 @@ func NewServer(bind string, options ...Option) (*Server, error) { // Schedular cron: cron.New(), + // Dispatcher + tasks: tasks, + // Auth Manager am: am, @@ -448,6 +461,9 @@ func NewServer(bind string, options ...Option) (*Server, error) { server.cron.Start() log.Info("started background jobs") + server.tasks.Start() + log.Info("started task dispatcher") + server.crawler.Start() log.Infof("started crawler") diff --git a/internal/task.go b/internal/task.go new file mode 100644 index 0000000..79200c0 --- /dev/null +++ b/internal/task.go @@ -0,0 +1,47 @@ +package internal + +import "fmt" + +type TaskState int + +const ( + TaskStatePending TaskState = iota + TaskStateRunning + TaskStateComplete + TaskStateFailed +) + +func (t TaskState) String() string { + switch t { + case TaskStatePending: + return "pending" + case TaskStateRunning: + return "running" + case TaskStateComplete: + return "complete" + case TaskStateFailed: + return "failed" + default: + return "unknown" + } +} + +type TaskData map[string]string + +type TaskResult struct { + State string `json:"state"` + Error string `json:"error"` + Data TaskData `json:"data"` +} + +// Task is an interface that represents a single task to be executed by a +// worker. Any object can implement a `Task` if it implements the interface. +type Task interface { + fmt.Stringer + + ID() string + State() TaskState + Result() TaskResult + Error() error + Run() error +} diff --git a/internal/task_handler.go b/internal/task_handler.go new file mode 100644 index 0000000..65bf75e --- /dev/null +++ b/internal/task_handler.go @@ -0,0 +1,39 @@ +package internal + +import ( + "encoding/json" + "net/http" + + "github.com/julienschmidt/httprouter" + log "github.com/sirupsen/logrus" +) + +// TaskHandler ... +func (s *Server) TaskHandler() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + uuid := p.ByName("uuid") + + if uuid == "" { + log.Warn("no task uuid provided") + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + t, ok := s.tasks.Lookup(uuid) + if !ok { + log.Warnf("no task found by uuid: %s", uuid) + http.Error(w, "Task Not Found", http.StatusNotFound) + return + } + + data, err := json.Marshal(t.Result()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(data) + + } +} diff --git a/internal/utils.go b/internal/utils.go index 36362d7..fed3761 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -381,6 +381,14 @@ func URLForCached(baseURL, hash string) string { ) } +func URLForTask(baseURL, uuid string) string { + return fmt.Sprintf( + "%s/task/%s", + strings.TrimSuffix(baseURL, "/"), + uuid, + ) +} + // SafeParseInt ... func SafeParseInt(s string, d int) int { n, e := strconv.Atoi(s) diff --git a/internal/worker.go b/internal/worker.go new file mode 100644 index 0000000..aef0568 --- /dev/null +++ b/internal/worker.go @@ -0,0 +1,49 @@ +package internal + +import ( + log "github.com/sirupsen/logrus" +) + +// Worker attaches to a provided worker pool, and +// looks for tasks on its task channel +type Worker struct { + workerPool chan chan Task + taskChannel chan Task + quit chan bool +} + +// NewWorker creates a new worker using the given id and +// attaches to the provided worker pool. It also initializes +// the task/quit channels +func NewWorker(workerPool chan chan Task) *Worker { + return &Worker{ + workerPool: workerPool, + taskChannel: make(chan Task), + quit: make(chan bool), + } +} + +// Start initializes a select loop to listen for tasks to execute +func (w *Worker) Start() { + go func() { + for { + w.workerPool <- w.taskChannel + + select { + case task := <-w.taskChannel: + if err := task.Run(); err != nil { + log.WithError(err).Errorf("error running task %s", task) + } + case <-w.quit: + return + } + } + }() +} + +// Stop will end the task select loop for the worker +func (w *Worker) Stop() { + go func() { + w.quit <- true + }() +}