feat(websocket): add support for custom headers in WS check (#1180)
feat(websocket): add support for custom headers in QueryWebSocket function
This commit is contained in:
@@ -2183,7 +2183,7 @@ This works for SCTP based application.
|
||||
|
||||
|
||||
### Monitoring a WebSocket endpoint
|
||||
By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints at a very basic level:
|
||||
By prefixing `endpoints[].url` with `ws://` or `wss://`, you can monitor WebSocket endpoints:
|
||||
```yaml
|
||||
endpoints:
|
||||
- name: example
|
||||
|
||||
@@ -321,7 +321,7 @@ func Ping(address string, config *Config) (bool, time.Duration) {
|
||||
}
|
||||
|
||||
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
|
||||
func QueryWebSocket(address, body string, config *Config) (bool, []byte, error) {
|
||||
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
|
||||
const (
|
||||
Origin = "http://localhost/"
|
||||
MaximumMessageSize = 1024 // in bytes
|
||||
@@ -330,6 +330,14 @@ func QueryWebSocket(address, body string, config *Config) (bool, []byte, error)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
|
||||
}
|
||||
if headers != nil {
|
||||
if wsConfig.Header == nil {
|
||||
wsConfig.Header = make(http.Header)
|
||||
}
|
||||
for name, value := range headers {
|
||||
wsConfig.Header.Set(name, value)
|
||||
}
|
||||
}
|
||||
if config != nil {
|
||||
wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout}
|
||||
wsConfig.TlsConfig = &tls.Config{
|
||||
|
||||
@@ -305,11 +305,11 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryWebSocket(t *testing.T) {
|
||||
_, _, err := QueryWebSocket("", "body", &Config{Timeout: 2 * time.Second})
|
||||
_, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second})
|
||||
if err == nil {
|
||||
t.Error("expected an error due to the address being invalid")
|
||||
}
|
||||
_, _, err = QueryWebSocket("ws://example.org", "body", &Config{Timeout: 2 * time.Second})
|
||||
_, _, err = QueryWebSocket("ws://example.org", "body", nil, &Config{Timeout: 2 * time.Second})
|
||||
if err == nil {
|
||||
t.Error("expected an error due to the target not being websocket-friendly")
|
||||
}
|
||||
|
||||
@@ -399,7 +399,16 @@ func (e *Endpoint) call(result *Result) {
|
||||
} else if endpointType == TypeICMP {
|
||||
result.Connected, result.Duration = client.Ping(strings.TrimPrefix(e.URL, "icmp://"), e.ClientConfig)
|
||||
} else if endpointType == TypeWS {
|
||||
result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), e.ClientConfig)
|
||||
wsHeaders := map[string]string{}
|
||||
if e.Headers != nil {
|
||||
for k, v := range e.Headers {
|
||||
wsHeaders[k] = v
|
||||
}
|
||||
}
|
||||
if _, exists := wsHeaders["User-Agent"]; !exists {
|
||||
wsHeaders["User-Agent"] = GatusUserAgent
|
||||
}
|
||||
result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), wsHeaders, e.ClientConfig)
|
||||
if err != nil {
|
||||
result.AddError(err.Error())
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user