diff --git a/README.md b/README.md index 32d3332b..fdf5202d 100644 --- a/README.md +++ b/README.md @@ -2183,7 +2183,7 @@ This works for SCTP based application. ### Monitoring a WebSocket endpoint -By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints at a very basic level: +By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints: ```yaml endpoints: - name: example diff --git a/client/client.go b/client/client.go index 4da16989..bc8518e4 100644 --- a/client/client.go +++ b/client/client.go @@ -321,7 +321,7 @@ func Ping(address string, config *Config) (bool, time.Duration) { } // QueryWebSocket opens a websocket connection, write `body` and return a message from the server -func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) { +func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) { const ( Origin = "http://localhost/" MaximumMessageSize = 1024 // in bytes @@ -330,6 +330,14 @@ func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) if err != nil { return false, nil, fmt.Errorf("error configuring websocket connection: %w", err) } + if headers != nil { + if wsConfig.Header == nil { + wsConfig.Header = make(http.Header) + } + for name, value := range headers { + wsConfig.Header.Set(name, value) + } + } if config != nil { wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout} wsConfig.TlsConfig = &tls.Config{ diff --git a/client/client_test.go b/client/client_test.go index f91fd461..2e84031a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -305,11 +305,11 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) { } func TestQueryWebSocket(t *testing.T) { - _, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second}) + _, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second}) if err == nil { t.Error("expected an error due to the address being invalid") } - _, _, err = QueryWebSocket("ws://example.org", "body", &Config{Timeout: 2 * time.Second}) + _, _, err = QueryWebSocket("ws://example.org", "body", nil, &Config{Timeout: 2 * time.Second}) if err == nil { t.Error("expected an error due to the target not being websocket-friendly") } diff --git a/config/endpoint/endpoint.go b/config/endpoint/endpoint.go index ab50f683..f72f675f 100644 --- a/config/endpoint/endpoint.go +++ b/config/endpoint/endpoint.go @@ -399,7 +399,16 @@ func (e *Endpoint) call(result *Result) { } else if endpointType == TypeICMP { result.Connected, result.Duration = client.Ping(strings.TrimPrefix(e.URL, "icmp://"), e.ClientConfig) } else if endpointType == TypeWS { - result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), e.ClientConfig) + wsHeaders := map[string]string{} + if e.Headers != nil { + for k, v := range e.Headers { + wsHeaders[k] = v + } + } + if _, exists := wsHeaders["User-Agent"]; !exists { + wsHeaders["User-Agent"] = GatusUserAgent + } + result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), wsHeaders, e.ClientConfig) if err != nil { result.AddError(err.Error()) return