feat(websocket): add support for custom headers in WS check (#1180)

feat(websocket): add support for custom headers in QueryWebSocket function
This commit is contained in:
Olexandr Dovgun
2025-08-02 21:06:46 +03:00
committed by GitHub
parent 8c5ad54e71
commit d27c63ded7
4 changed files with 22 additions and 5 deletions

View File

@@ -2183,7 +2183,7 @@ This works for SCTP based application.
### Monitoring a WebSocket endpoint
By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints at a very basic level:
By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints:
```yaml
endpoints:
- name: example

View File

@@ -321,7 +321,7 @@ func Ping(address string, config *Config) (bool, time.Duration) {
}
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) {
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
const (
Origin = "http://localhost/"
MaximumMessageSize = 1024 // in bytes
@@ -330,6 +330,14 @@ func QueryWebSocket(address, body string, config *Config) (bool, []byte, error)
if err != nil {
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
}
if headers != nil {
if wsConfig.Header == nil {
wsConfig.Header = make(http.Header)
}
for name, value := range headers {
wsConfig.Header.Set(name, value)
}
}
if config != nil {
wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout}
wsConfig.TlsConfig = &tls.Config{

View File

@@ -305,11 +305,11 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
}
func TestQueryWebSocket(t *testing.T) {
_, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second})
_, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second})
if err == nil {
t.Error("expected an error due to the address being invalid")
}
_, _, err = QueryWebSocket("ws://example.org", "body", &Config{Timeout: 2 * time.Second})
_, _, err = QueryWebSocket("ws://example.org", "body", nil, &Config{Timeout: 2 * time.Second})
if err == nil {
t.Error("expected an error due to the target not being websocket-friendly")
}

View File

@@ -399,7 +399,16 @@ func (e *Endpoint) call(result *Result) {
} else if endpointType == TypeICMP {
result.Connected, result.Duration = client.Ping(strings.TrimPrefix(e.URL, "icmp://"), e.ClientConfig)
} else if endpointType == TypeWS {
result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), e.ClientConfig)
wsHeaders := map[string]string{}
if e.Headers != nil {
for k, v := range e.Headers {
wsHeaders[k] = v
}
}
if _, exists := wsHeaders["User-Agent"]; !exists {
wsHeaders["User-Agent"] = GatusUserAgent
}
result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), wsHeaders, e.ClientConfig)
if err != nil {
result.AddError(err.Error())
return