diff --git a/storage/store/memory/memory.go b/storage/store/memory/memory.go index 79f4e505..4efb89ab 100644 --- a/storage/store/memory/memory.go +++ b/storage/store/memory/memory.go @@ -49,7 +49,7 @@ func (s *Store) GetAllEndpointStatuses(params *paging.EndpointStatusParams) ([]* pagedEndpointStatuses := make([]*endpoint.Status, 0, len(allStatuses)) for _, v := range allStatuses { if status, ok := v.(*endpoint.Status); ok { - pagedEndpointStatuses = append(pagedEndpointStatuses, ShallowCopyEndpointStatus(status, params)) + pagedEndpointStatuses = append(pagedEndpointStatuses, CopyEndpointStatus(status, params)) } } sort.Slice(pagedEndpointStatuses, func(i, j int) bool { @@ -87,7 +87,7 @@ func (s *Store) GetEndpointStatusByKey(key string, params *paging.EndpointStatus if endpointStatus == nil { return nil, common.ErrEndpointNotFound } - return ShallowCopyEndpointStatus(endpointStatus.(*endpoint.Status), params), nil + return CopyEndpointStatus(endpointStatus.(*endpoint.Status), params), nil } // GetSuiteStatusByKey returns the suite status for a given key diff --git a/storage/store/memory/util.go b/storage/store/memory/util.go index 514fda09..f7f24d24 100644 --- a/storage/store/memory/util.go +++ b/storage/store/memory/util.go @@ -6,35 +6,47 @@ import ( "github.com/TwiN/gatus/v5/storage/store/common/paging" ) -// ShallowCopyEndpointStatus returns a shallow copy of a Status with only the results -// within the range defined by the page and pageSize parameters -func ShallowCopyEndpointStatus(ss *endpoint.Status, params *paging.EndpointStatusParams) *endpoint.Status { - shallowCopy := &endpoint.Status{ +// CopyEndpointStatus returns a safe copy of a Status with only the results +// within the range defined by the page and pageSize parameters. +// This function performs deep copying of slices to prevent race conditions +// when the original slice is modified concurrently. +func CopyEndpointStatus(ss *endpoint.Status, params *paging.EndpointStatusParams) *endpoint.Status { + statusCopy := &endpoint.Status{ Name: ss.Name, Group: ss.Group, Key: ss.Key, Uptime: endpoint.NewUptime(), } if params == nil || (params.ResultsPage == 0 && params.ResultsPageSize == 0 && params.EventsPage == 0 && params.EventsPageSize == 0) { - shallowCopy.Results = ss.Results - shallowCopy.Events = ss.Events + // Deep copy all results to prevent race conditions + statusCopy.Results = make([]*endpoint.Result, len(ss.Results)) + copy(statusCopy.Results, ss.Results) + // Deep copy all events to prevent race conditions + statusCopy.Events = make([]*endpoint.Event, len(ss.Events)) + copy(statusCopy.Events, ss.Events) } else { numberOfResults := len(ss.Results) resultsStart, resultsEnd := getStartAndEndIndex(numberOfResults, params.ResultsPage, params.ResultsPageSize) if resultsStart < 0 || resultsEnd < 0 { - shallowCopy.Results = []*endpoint.Result{} + statusCopy.Results = []*endpoint.Result{} } else { - shallowCopy.Results = ss.Results[resultsStart:resultsEnd] + // Deep copy the slice range to prevent race conditions + resultRange := ss.Results[resultsStart:resultsEnd] + statusCopy.Results = make([]*endpoint.Result, len(resultRange)) + copy(statusCopy.Results, resultRange) } numberOfEvents := len(ss.Events) eventsStart, eventsEnd := getStartAndEndIndex(numberOfEvents, params.EventsPage, params.EventsPageSize) if eventsStart < 0 || eventsEnd < 0 { - shallowCopy.Events = []*endpoint.Event{} + statusCopy.Events = []*endpoint.Event{} } else { - shallowCopy.Events = ss.Events[eventsStart:eventsEnd] + // Deep copy the slice range to prevent race conditions + eventRange := ss.Events[eventsStart:eventsEnd] + statusCopy.Events = make([]*endpoint.Event, len(eventRange)) + copy(statusCopy.Events, eventRange) } } - return shallowCopy + return statusCopy } // ShallowCopySuiteStatus returns a shallow copy of a suite Status with only the results diff --git a/storage/store/memory/util_bench_test.go b/storage/store/memory/util_bench_test.go index d8f1d26a..018e12ed 100644 --- a/storage/store/memory/util_bench_test.go +++ b/storage/store/memory/util_bench_test.go @@ -8,14 +8,14 @@ import ( "github.com/TwiN/gatus/v5/storage/store/common/paging" ) -func BenchmarkShallowCopyEndpointStatus(b *testing.B) { +func BenchmarkCopyEndpointStatus(b *testing.B) { ep := &testEndpoint status := endpoint.NewStatus(ep.Group, ep.Name) for i := 0; i < storage.DefaultMaximumNumberOfResults; i++ { AddResult(status, &testSuccessfulResult, storage.DefaultMaximumNumberOfResults, storage.DefaultMaximumNumberOfEvents) } for n := 0; n < b.N; n++ { - ShallowCopyEndpointStatus(status, paging.NewEndpointStatusParams().WithResults(1, 20)) + CopyEndpointStatus(status, paging.NewEndpointStatusParams().WithResults(1, 20)) } b.ReportAllocs() } diff --git a/storage/store/memory/util_test.go b/storage/store/memory/util_test.go index 927ca937..2a24d37f 100644 --- a/storage/store/memory/util_test.go +++ b/storage/store/memory/util_test.go @@ -26,7 +26,7 @@ func TestAddResult(t *testing.T) { AddResult(nil, &endpoint.Result{Timestamp: time.Now()}, storage.DefaultMaximumNumberOfResults, storage.DefaultMaximumNumberOfEvents) } -func TestShallowCopyEndpointStatus(t *testing.T) { +func TestCopyEndpointStatus(t *testing.T) { ep := &endpoint.Endpoint{Name: "name", Group: "group"} endpointStatus := endpoint.NewStatus(ep.Group, ep.Name) ts := time.Now().Add(-25 * time.Hour) @@ -34,34 +34,34 @@ func TestShallowCopyEndpointStatus(t *testing.T) { AddResult(endpointStatus, &endpoint.Result{Success: i%2 == 0, Timestamp: ts}, storage.DefaultMaximumNumberOfResults, storage.DefaultMaximumNumberOfEvents) ts = ts.Add(time.Hour) } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(-1, -1)).Results) != 0 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(-1, -1)).Results) != 0 { t.Error("expected to have 0 result") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 1)).Results) != 1 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 1)).Results) != 1 { t.Error("expected to have 1 result") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(5, 0)).Results) != 0 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(5, 0)).Results) != 0 { t.Error("expected to have 0 results") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(-1, 20)).Results) != 0 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(-1, 20)).Results) != 0 { t.Error("expected to have 0 result, because the page was invalid") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, -1)).Results) != 0 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, -1)).Results) != 0 { t.Error("expected to have 0 result, because the page size was invalid") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 10)).Results) != 10 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 10)).Results) != 10 { t.Error("expected to have 10 results, because given a page size of 10, page 1 should have 10 elements") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(2, 10)).Results) != 10 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(2, 10)).Results) != 10 { t.Error("expected to have 10 results, because given a page size of 10, page 2 should have 10 elements") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(3, 10)).Results) != 5 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(3, 10)).Results) != 5 { t.Error("expected to have 5 results, because given a page size of 10, page 3 should have 5 elements") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(4, 10)).Results) != 0 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(4, 10)).Results) != 0 { t.Error("expected to have 0 results, because given a page size of 10, page 4 should have 0 elements") } - if len(ShallowCopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 50)).Results) != 25 { + if len(CopyEndpointStatus(endpointStatus, paging.NewEndpointStatusParams().WithResults(1, 50)).Results) != 25 { t.Error("expected to have 25 results, because there's only 25 results") } }